【笔记】Splay Tree

挖坑：原理讲解待补充。
以下先贴代码：

#include <cstdio>
using namespace std;
// Node-pool capacity. Index 0 is the null sentinel, so at most MAXN-1 real
// nodes can be allocated (indices 1..MAXN-1).
const int MAXN = 100005;
// Large finite value used as the upper ("+infinity") sentinel key.
const int sup = 0x3f3f3f3f;
// Lower ("-infinity") sentinel key. Note: despite its name, `inf` is the
// NEGATIVE bound; every real key is expected to lie strictly between the two.
const int inf = -sup;

// Splay tree storing a multiset of ints in fixed-size arrays.
// Node 0 is the null sentinel (children/parent 0, cnt 0, sz 0); real nodes are
// 1..tot. Each node holds one distinct key val[] with multiplicity cnt[], and
// sz[] is the total multiplicity of the node's subtree.
struct splayTree {
	int root, tot;
	int ch[MAXN][2], fa[MAXN], val[MAXN], cnt[MAXN], sz[MAXN];

	// Empty tree. Node 0 is zeroed explicitly so an instance that is not
	// statically zero-initialized still has a valid null sentinel; every other
	// node's fields are fully written by insert() when it is allocated.
	splayTree() {
		root = tot = 0;
		ch[0][0] = ch[0][1] = fa[0] = val[0] = cnt[0] = sz[0] = 0;
	}

	// 1 if x is the right child of its parent, 0 otherwise.
	int chk(int x) const { return x == ch[fa[x]][1]; }

	// Recompute sz[x] from its children's sizes and its own multiplicity.
	void update(int x) { sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x]; }

	// Rotate x above its parent y, preserving in-order key order.
	// Links that would point from/into node 0 (no grandparent, or no inner
	// child w) are skipped so the null sentinel is never modified.
	void rotate(int x) {
		int y = fa[x], z = fa[y], k = chk(x), w = ch[x][k ^ 1];
		if (z) ch[z][chk(y)] = x;  // must run before fa[y] changes
		fa[x] = z;
		ch[x][k ^ 1] = y, fa[y] = x;
		ch[y][k] = w;
		if (w) fa[w] = y;
		update(y), update(x);  // y is now below x, so refresh it first
	}

	// Splay x upward until its parent is `goal` (goal == 0 makes x the root).
	// When x and its parent hang on the same side the parent is rotated first
	// (zig-zig); otherwise x is rotated twice (zig-zag).
	void splay(int x, int goal = 0) {
		while (fa[x] != goal) {
			int y = fa[x], z = fa[y];
			if (z != goal) rotate(chk(x) == chk(y) ? y : x);
			rotate(x);
		}
		if (!goal) root = x;
	}

	// Add one occurrence of key x, then splay its node to the root.
	// No capacity check: inserting more than MAXN-1 distinct keys overflows.
	void insert(int x) {
		int cur = root, p = 0;  // p tracks cur's parent during the descent
		while (cur && x != val[cur]) p = cur, cur = ch[cur][val[cur] < x];
		if (cur) ++cnt[cur];
		else {
			cur = ++tot;
			if (p) ch[p][val[p] < x] = cur;
			ch[cur][0] = ch[cur][1] = 0;
			fa[cur] = p, val[cur] = x;
			cnt[cur] = sz[cur] = 1;
		}
		splay(cur);
		// If cur was already the root no rotation ran, so its size has to be
		// refreshed here to reflect the incremented cnt.
		update(cur);
	}

	// Splay to the root the node holding x or, if x is absent, the last node
	// visited while searching for x. Despite the name it returns nothing; the
	// caller derives ranks from the resulting root.
	void get_rank(int x) {
		int cur = root;
		while (x != val[cur] && ch[cur][val[cur] < x]) cur = ch[cur][val[cur] < x];
		splay(cur);
	}

	// Return the node holding the k-th smallest element (1-based, counting
	// multiplicity) and splay it to the root. If k exceeds the total count,
	// the node with the largest key is returned.
	int get_by_rank(int k) {
		int cur = root;
		for (;;) {
			int left = sz[ch[cur][0]];
			if (ch[cur][0] && left >= k) cur = ch[cur][0];
			else if (ch[cur][1] && left + cnt[cur] < k) k -= left + cnt[cur], cur = ch[cur][1];
			else break;
		}
		// Splaying the accessed node is what keeps repeated k-th queries
		// within the splay tree's amortized logarithmic bound.
		splay(cur);
		return cur;
	}

	// Node (index, not val[]) of the largest key strictly less than x.
	// Assumes such a key exists (e.g. the lower sentinel is stored).
	int pre(int x) {
		get_rank(x);
		if (val[root] < x) return root;
		int cur = ch[root][0];
		while (ch[cur][1]) cur = ch[cur][1];
		return cur;
	}

	// Node (index, not val[]) of the smallest key strictly greater than x.
	// Assumes such a key exists (e.g. the upper sentinel is stored).
	int nxt(int x) {
		get_rank(x);
		if (x < val[root]) return root;
		int cur = ch[root][1];
		while (ch[cur][0]) cur = ch[cur][0];
		return cur;
	}

	// Remove one occurrence of x (no structural change if x is absent).
	// pre(x) is splayed to the root and nxt(x) directly below it, which leaves
	// x, if present, as the sole node in nxt(x)'s left subtree.
	void remove(int x) {
		int prev = pre(x), next = nxt(x);
		splay(prev), splay(next, prev);
		int del = ch[next][0];
		if (cnt[del] > 1) {
			--cnt[del];
			splay(del);  // the rotations refresh sz along the whole path
		} else {
			ch[next][0] = 0;
			// Detaching del shrinks both subtrees above it; without these
			// updates their sz stayed stale and later rank queries went wrong.
			update(next), update(prev);
		}
	}
} ST;
int T;
int main()
{
	ST.insert(inf), ST.insert(sup);
	for (scanf("%d", &T); T; T--) {
		int opt, x; scanf("%d%d", &opt, &x);
		switch (opt) {
			case 1: ST.insert(x); break;
			case 2: ST.remove(x); break;
			case 3: ST.get_rank(x); printf("%d\n", ST.sz[ST.ch[ST.root][0]]); break;
			case 4: printf("%d\n", ST.val[ST.get_by_rank(x+1)]); break;
			case 5: printf("%d\n", ST.val[ST.pre(x)]); break;
			case 6: printf("%d\n", ST.val[ST.nxt(x)]); break;
		}
	}
}
posted @ 2021-08-19 18:26  zrkc  阅读(47)  评论(0)    收藏  举报