【笔记】Splay Tree
挖坑：原理部分待补充。
下面先贴代码，支持插入、删除、查询排名、按排名取数、求前驱、求后继六种操作。
#include <cstdio>
using namespace std;
const int MAXN = 100005;     // node capacity; slot 0 is reserved as the null node
const int sup = 0x3f3f3f3f;  // guard key placed above every real key
const int inf = -sup;        // guard key placed below every real key

// Array-based splay tree storing a multiset of ints. Equal keys share one
// node whose cnt[] is the multiplicity; sz[] is the number of keys (counting
// multiplicity) in that node's subtree. Index 0 is the null node and must keep
// ch[0][*] == fa[0] == cnt[0] == sz[0] == 0. Operations that search the tree
// splay the node they touch to the root (amortized O(log n) per operation).
struct splayTree {
    int root, tot;  // current root index; number of node slots handed out
    int ch[MAXN][2], fa[MAXN], val[MAXN], cnt[MAXN], sz[MAXN];

    splayTree() {
        root = tot = 0;
        // Clear the null node explicitly so a non-static instance is valid too.
        ch[0][0] = ch[0][1] = fa[0] = val[0] = cnt[0] = sz[0] = 0;
    }

    // Side of its parent that x hangs on: 1 for a right child, 0 otherwise.
    int chk(int x) const { return x == ch[fa[x]][1]; }

    // Recompute sz[x] from its two children and its own multiplicity.
    void update(int x) { sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x]; }

    // Single rotation lifting x above its parent y; x's inner subtree w is
    // re-hung under y. The null node's links are never written.
    void rotate(int x) {
        int y = fa[x], z = fa[y], k = chk(x), w = ch[x][k ^ 1];
        fa[x] = z;
        if (z) ch[z][chk(y)] = x;  // fa[y] is still z here, so chk(y) is valid
        fa[y] = x, ch[x][k ^ 1] = y;
        ch[y][k] = w;
        if (w) fa[w] = y;
        update(y), update(x);  // y now sits below x, so refresh it first
    }

    // Rotate x upward until its parent is goal; goal == 0 makes x the root.
    // When x and its parent lean the same way, the parent is rotated first
    // (zig-zig); otherwise x is rotated twice (zig-zag).
    void splay(int x, int goal = 0) {
        while (fa[x] != goal) {
            int y = fa[x], z = fa[y];
            if (z != goal) rotate(chk(x) == chk(y) ? y : x);
            rotate(x);
        }
        if (!goal) root = x;
    }

    // Add one copy of x and splay its node to the root.
    // NOTE: no capacity check; more than MAXN - 1 distinct keys overflows.
    void insert(int x) {
        int cur = root, p = 0;  // p trails cur as its parent
        while (cur && x != val[cur]) p = cur, cur = ch[cur][val[cur] < x];
        if (cur) {
            ++cnt[cur];
            update(cur);  // splay() performs no rotation when cur is the root
        } else {
            cur = ++tot;
            if (p) ch[p][val[p] < x] = cur;
            cnt[cur] = sz[cur] = 1;
            fa[cur] = p, val[cur] = x;
            ch[cur][0] = ch[cur][1] = 0;
        }
        splay(cur);
    }

    // Despite its name, this returns nothing. It descends toward x and splays
    // the last node visited to the root: the node holding x if present,
    // otherwise x's predecessor or successor. Callers then read
    // sz[ch[root][0]], the number of keys smaller than val[root].
    void get_rank(int x) {
        int cur = root;
        while (x != val[cur] && ch[cur][val[cur] < x]) cur = ch[cur][val[cur] < x];
        splay(cur);
    }

    // Node holding the k-th smallest key (1-based, counting multiplicities and
    // the guard keys). If k exceeds the tree size, the maximum node is
    // returned. Does not splay.
    int get_by_rank(int k) {
        int cur = root;
        for (;;) {
            if (ch[cur][0] && sz[ch[cur][0]] >= k) {
                cur = ch[cur][0];
            } else if (ch[cur][1] && sz[ch[cur][0]] + cnt[cur] < k) {
                k -= sz[ch[cur][0]] + cnt[cur];
                cur = ch[cur][1];
            } else {
                return cur;
            }
        }
    }

    // Node index (not key) with the largest key strictly less than x,
    // or 0 if there is none.
    int pre(int x) {
        get_rank(x);
        if (x > val[root]) return root;
        int cur = ch[root][0];
        while (ch[cur][1]) cur = ch[cur][1];
        return cur;
    }

    // Node index (not key) with the smallest key strictly greater than x,
    // or 0 if there is none.
    int nxt(int x) {
        get_rank(x);
        if (x < val[root]) return root;
        int cur = ch[root][1];
        while (ch[cur][0]) cur = ch[cur][0];
        return cur;
    }

    // Remove one copy of x; the key set is unchanged if x is absent.
    // Relies on the guard keys so that x has a strict predecessor and
    // successor in the tree (i.e. inf < x < sup).
    void remove(int x) {
        int prev = pre(x), next = nxt(x);
        // With prev at the root and next as its right child, the keys strictly
        // between them (only x) form the subtree ch[next][0].
        splay(prev), splay(next, prev);
        int del = ch[next][0];
        if (cnt[del] > 1) {
            --cnt[del];
            splay(del);  // the rotations refresh sz along the path to the root
        } else {
            ch[next][0] = 0;
            update(next), update(prev);  // both subtrees just lost one key
        }
    }
} ST;
int T;
// Reads T operations "opt x" from stdin and answers them on ST.
// The tree always holds the guard keys inf and sup. Queries compensate for
// the inf guard, which is always the smallest key. Inputs are assumed to
// satisfy inf < x < sup.
int main()
{
    ST.insert(inf), ST.insert(sup);
    if (scanf("%d", &T) != 1) return 0;
    for (; T; T--) {
        int opt, x;
        if (scanf("%d%d", &opt, &x) != 2) break;  // truncated input
        switch (opt) {
        case 1: ST.insert(x); break;
        case 2: ST.remove(x); break;
        case 3: {
            // get_rank(x) splays x, or its predecessor/successor if x is absent.
            // sz of the left subtree counts the inf guard, which supplies the
            // "+1" in rank = (#keys < x) + 1.
            ST.get_rank(x);
            int rank = ST.sz[ST.ch[ST.root][0]];
            // If the splayed node is x's predecessor, its copies are < x too.
            if (ST.val[ST.root] < x) rank += ST.cnt[ST.root];
            printf("%d\n", rank);
            break;
        }
        case 4: printf("%d\n", ST.val[ST.get_by_rank(x + 1)]); break;  // +1 skips the inf guard
        case 5: printf("%d\n", ST.val[ST.pre(x)]); break;
        case 6: printf("%d\n", ST.val[ST.nxt(x)]); break;
        }
    }
    return 0;
}