BZOJ 3224 普通平衡树 | 平衡树模板

#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#define space putchar(' ')
#define enter putchar('\n')
using namespace std;
typedef long long ll;
template <class T>
void read(T &x){
    char c;
    bool op = 0;
    while(c = getchar(), c < '0' || c > '9')
	if(c == '-') op = 1;
    x = c - '0';
    while(c = getchar(), c >= '0' && c <= '9')
	x = x * 10 + c - '0';
    if(op) x = -x;
}
template <class T>
void write(T x){
    if(x < 0) putchar('-'), x = -x;
    if(x >= 10) write(x / 10);
    putchar('0' + x % 10);
}
//欢迎阅读胡小兔的平衡树板子 =v=
const int N = 100005;
int n, root, idx, val[N], fa[N], ls[N], rs[N], sze[N], cnt[N];
#define which(x) (ls[fa[(x)]] == (x)) //判断x的"方向": x是左儿子还是右儿子

void upt(int x){ //update: 更新sze[x]
    sze[x] = sze[ls[x]] + sze[rs[x]] + cnt[x];
}
void rotate(int x){ //如果x是左儿子则右旋,右儿子则左旋
    int y = fa[x], z = fa[y], b = which(x) ? rs[x] : ls[x], dir = which(y);
    which(x) ? (rs[x] = y, ls[y] = b) : (ls[x] = y, rs[y] = b);
    fa[y] = x, fa[b] = y, fa[x] = z;
    if(z) dir ? ls[z] = x : rs[z] = x;
    upt(y), upt(x); //记得旋转之后更新大小,由下往上更新,此时y在下而x在上
}
void splay(int x){//将x旋转至根节点
    while(fa[x]){//原则:为了尽可能使树平衡,如果x和fa[x]方向相同则先旋转fa再旋转x,否则旋转两次x
	if(fa[fa[x]]){
	    if(which(x) == which(fa[x])) rotate(fa[x]);
	    else rotate(x);
	}
	rotate(x);
    }
    root = x; //记得更新根节点
}
int find(int x){ //找到值为x的节点; 如果没有则返回
    int cur = root, last = 0;
    while(cur && val[cur] != x){
	last = cur;
	if(x < val[cur]) cur = ls[cur];
	else cur = rs[cur];
    }
    return cur ? cur : last;
}
int getmin(int x){ //找子树x中最小的点的编号
    while(ls[x]) x = ls[x];
    return x;
}
int getmax(int x){ //找子树x中最大的点的编号
    while(rs[x]) x = rs[x];
    return x;
}
void insert(int x){ //插入一个数
    int cur = find(x); //找到值最相近的节点的编号
    if(cur && val[cur] == x) return (void)(cnt[cur]++, sze[cur]++, splay(cur)); //如果已存在这个节点,则cnt++
    val[++idx] = x, fa[idx] = cur, cnt[idx] = sze[idx] = 1;// 如果不存在这个节点,则新增一个节点
    if(cur) x < val[cur] ? ls[cur] = idx : rs[cur] = idx;
    splay(idx);
}
void erase(int x){
    int cur = find(x);
    splay(cur);
    if(cnt[cur] > 1) cnt[cur]--, sze[cur]--; //如果这个值去掉一个之后还存在,则只要cnt--就好了
    else if(!ls[cur] || !rs[cur]) root = ls[cur] + rs[cur], fa[root] = 0; //如果至少一个儿子为空,则让那个儿子做根节点;如果两个儿子均为空,则说明删除这个点后整棵树为空
    else{
	fa[ls[cur]] = 0; //让左子树中最大的点做根节点,右子树做新根节点的右子树
	int u = getmax(ls[cur]);
	splay(u);
	rs[u] = rs[cur], fa[rs[cur]] = u;
	upt(u);
    }
}
int getkth(int k){ //找排名为k的数,类似权值线段树
    int cur = root;
    while(cur){
	if(sze[ls[cur]] >= k) cur = ls[cur];
	else if(sze[ls[cur]] + cnt[cur] >= k) return val[cur];
	else k -= sze[ls[cur]] + cnt[cur], cur = rs[cur];
    }
    return val[cur];
}
int getrank(int x){ //求x的排名
    int cur = find(x);
    splay(cur);
    return sze[ls[cur]] + 1;
}
int getpre(int x){
    int cur = find(x);
    if(val[cur] < x) return val[cur];
    splay(cur);
    return val[getmax(ls[cur])];
}
int getnxt(int x){
    int cur = find(x);
    if(val[cur] > x) return val[cur];
    splay(cur);
    return val[getmin(rs[cur])];
}
int main(){
    read(n);
    while(n--){
	int op, x;
	read(op), read(x);
	if(op == 1) insert(x);
	if(op == 2) erase(x);
	if(op == 3) write(getrank(x)), enter;
	if(op == 4) write(getkth(x)), enter;
	if(op == 5) write(getpre(x)), enter;
	if(op == 6) write(getnxt(x)), enter;
    }
    return 0;
}
posted @ 2017-12-02 15:22  胡小兔  阅读(265)  评论(0编辑  收藏  举报