[BZOJ3224]Tyvj 1728 普通平衡树
试题描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,应只删除一个)
3. 查询x数的排名(排名定义为比x小的数的个数+1;若有多个相同的数,应输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)
输入
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)
输出
对于操作3,4,5,6每行输出一个数,表示对应答案
输入示例
10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598
输出示例
106465
84185
492737
数据规模及约定
1.n的数据范围:n<=100000
2.每个数的数据范围:[-1e7,1e7]
题解
treap 模板题。
// Rotation-based treap (max-heap on random priorities) for BZOJ 3224:
// an ordered multiset of ints supporting insert, delete-one, rank, k-th,
// predecessor and successor queries. Commands are read from stdin.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <algorithm>
using namespace std;

// Reads one signed decimal integer from stdin. Non-digit characters before
// the number are skipped; a '-' among them makes the result negative.
// Returns 0 if input ends before any digit (instead of spinning forever).
int read() {
	int x = 0, f = 1;
	int c = getchar(); // int, not char: keeps EOF distinguishable and isdigit() well-defined
	while(c != EOF && !isdigit(c)) {
		if(c == '-') f = -1;
		c = getchar();
	}
	while(isdigit(c)) {
		x = x * 10 + c - '0';
		c = getchar();
	}
	return x * f;
}

#define maxn 100010

// Pool node: v = stored value, r = random priority (parents have the larger r),
// siz = number of nodes in the subtree rooted here.
struct Node {
	int v, r, siz;
	Node() {}
	Node(int value, int priority): v(value), r(priority), siz(1) {}
} ns[maxn];

// rt = root index (0 = empty tree); ToT = number of pool slots used.
// fa[o] = parent of o (0 for the root); ch[0][o] / ch[1][o] = left / right child.
int rt, ToT, fa[maxn], ch[2][maxn];

// Recomputes ns[o].siz from o's children. No-op on the null node 0.
void maintain(int o) {
	if(!o) return;
	ns[o].siz = 1;
	for(int i = 0; i < 2; i++) if(ch[i][o]) ns[o].siz += ns[ch[i][o]].siz;
}

// Rotates u above its parent y, preserving in-order. Updates fa/ch links,
// including the grandparent's child pointer when a grandparent exists.
void rotate(int u) {
	int y = fa[u], z = fa[y], l = 0, r = 1;
	if(z) ch[ch[1][z] == y][z] = u;
	if(ch[1][y] == u) swap(l, r);
	int inner = ch[r][u]; // subtree that moves from u to y
	fa[u] = z;
	fa[y] = u;
	if(inner) fa[inner] = y; // never write a parent into the null slot 0
	ch[l][y] = inner;
	ch[r][u] = y;
	maintain(y);
	maintain(u);
}

// Inserts v into the subtree rooted at o; o is updated if the root changes.
// Equal values descend into the left subtree.
void insert(int& o, int v) {
	if(!o) {
		o = ++ToT;
		ns[o] = Node(v, rand());
		maintain(o);
		return;
	}
	bool d = v > ns[o].v;
	insert(ch[d][o], v);
	fa[ch[d][o]] = o;
	// Restore the heap property: a child must not outrank its parent.
	if(ns[ch[d][o]].r > ns[o].r) {
		int t = ch[d][o];
		rotate(t);
		o = t;
	}
	maintain(o);
}

// Removes one occurrence of v from the subtree rooted at o, if present.
void del(int& o, int v) {
	if(!o) return;
	if(ns[o].v == v) {
		if(!ch[0][o] || !ch[1][o]) {
			// At most one child: splice it into o's place (or empty the slot).
			int t = ch[0][o] ? ch[0][o] : ch[1][o];
			if(t) fa[t] = fa[o];
			o = t;
		}
		else {
			// Two children: rotate the higher-priority child up, then keep
			// deleting in the subtree where the target node ended up.
			bool d = ns[ch[1][o]].r > ns[ch[0][o]].r;
			int t = ch[d][o];
			rotate(t);
			o = t;
			del(ch[d ^ 1][o], v);
		}
	}
	else del(ch[v > ns[o].v][o], v);
	maintain(o);
}

// Number of stored values strictly less than v.
int qrank(int o, int v) {
	if(!o) return 0;
	int ls = ch[0][o] ? ns[ch[0][o]].siz : 0;
	if(v > ns[o].v) return ls + 1 + qrank(ch[1][o], v);
	return qrank(ch[0][o], v);
}

// Sentinels returned when a query has no answer; outside the [-1e7, 1e7] input range.
const int err = -233333333;
const int errm = 233333333;

// k-th smallest stored value (1-based), or err if k is out of range.
int qkth(int o, int k) {
	if(!o) return err;
	int ls = ch[0][o] ? ns[ch[0][o]].siz : 0;
	if(k == ls + 1) return ns[o].v;
	if(k > ls + 1) return qkth(ch[1][o], k - ls - 1);
	return qkth(ch[0][o], k);
}

// Largest stored value strictly less than v, or err if none.
int qlow(int o, int v) {
	if(!o) return err;
	if(v > ns[o].v) return max(ns[o].v, qlow(ch[1][o], v));
	return qlow(ch[0][o], v);
}

// Smallest stored value strictly greater than v, or errm if none.
int qupp(int o, int v) {
	if(!o) return errm;
	if(v < ns[o].v) return min(ns[o].v, qupp(ch[0][o], v));
	return qupp(ch[1][o], v);
}

int main() {
	int q = read();
	while(q--) {
		int tp = read(), v = read();
		if(tp == 1) insert(rt, v);
		if(tp == 2) del(rt, v);
		if(tp == 3) printf("%d\n", qrank(rt, v) + 1);
		if(tp == 4) printf("%d\n", qkth(rt, v));
		if(tp == 5) printf("%d\n", qlow(rt, v));
		if(tp == 6) printf("%d\n", qupp(rt, v));
	}
	return 0;
}
再贴一个替罪羊树版本的。
// Scapegoat tree with lazy deletion for BZOJ 3224. Deleted values stay in the
// tree as marked nodes until a rebuild drops them. A subtree is rebuilt when
// one child holds more than Bili of its physical nodes (after insert), or when
// fewer than Bili of its physical nodes are still live (after delete).
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <stack>
#include <vector>
#include <queue>
#include <cstring>
#include <string>
#include <map>
#include <set>
#include <cctype>
using namespace std;

const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;

// Buffered single-byte reader over stdin. Returns the next byte as an
// unsigned-char value, or EOF once input is exhausted (never reads stale data).
inline int Getchar() {
	if(Head == Tail) {
		size_t l = fread(buffer, 1, BufferSize, stdin);
		if(l == 0) return EOF;
		Tail = (Head = buffer) + l;
	}
	return static_cast<unsigned char>(*Head++);
}

// Reads one signed decimal integer; returns 0 if input ends before a digit.
int read() {
	int x = 0, f = 1;
	int c = Getchar();
	while(c != EOF && !isdigit(c)) {
		if(c == '-') f = -1;
		c = Getchar();
	}
	while(isdigit(c)) {
		x = x * 10 + c - '0';
		c = Getchar();
	}
	return x * f;
}

#define maxn 100010
#define oo 2147483647

// v = value; siz = live (unmarked) nodes in the subtree; reas = physical nodes
// in the subtree, marked ones included; mx / mn = max / min live value in the
// subtree (-oo / oo when there is none); del = lazy-deletion mark.
struct Node {
	int v, siz, reas, mx, mn;
	bool del;
	Node() {}
	explicit Node(int value): v(value), del(false) {}
} ns[maxn];

// rt = root (0 = empty); ToT = pool slots used; fa = parent; ch[o][0/1] = children.
int rt, ToT, fa[maxn], ch[maxn][2];

// Recomputes siz / reas / mx / mn of o from its own mark and its children.
void maintain(int o) {
	Node& n = ns[o];
	n.siz = n.del ^ 1;
	n.reas = 1;
	n.mx = n.del ? -oo : n.v;
	n.mn = n.del ? oo : n.v;
	for(int i = 0; i < 2; i++) {
		int c = ch[o][i];
		if(!c) continue;
		n.siz += ns[c].siz;
		n.reas += ns[c].reas;
		n.mx = max(n.mx, ns[c].mx);
		n.mn = min(n.mn, ns[c].mn);
	}
}

const double Bili = .6; // balance factor alpha

// True if one child of o holds more than Bili of o's physical nodes.
bool unbal(int o) {
	int lr = ch[o][0] ? ns[ch[o][0]].reas : 0;
	int rr = ch[o][1] ? ns[ch[o][1]].reas : 0;
	return max(lr, rr) > Bili * ns[o].reas;
}

// Topmost node found out of balance during the current update (0 = none).
int rb;

// Plain BST insert; equal values go left. Records the topmost unbalanced node
// in rb (assigned on the way back up, so the last write is the highest one).
void insert(int& o, int v) {
	if(!o) {
		o = ++ToT;
		ns[o] = Node(v);
		maintain(o);
		return;
	}
	bool d = v > ns[o].v;
	insert(ch[o][d], v);
	fa[ch[o][d]] = o;
	maintain(o);
	if(unbal(o)) rb = o;
}

// Live nodes of the subtree being rebuilt, in order (1-based).
// Named `seq` rather than `get`: a global `get` is ambiguous with std::get
// under `using namespace std` in C++11 and later.
int cntn, seq[maxn];

// In-order collects live nodes of o's subtree into seq and detaches every node.
void getnode(int o) {
	if(!o) return;
	getnode(ch[o][0]);
	if(!ns[o].del) seq[++cntn] = o;
	getnode(ch[o][1]);
	fa[o] = ch[o][0] = ch[o][1] = 0;
}

// Builds a perfectly balanced tree from seq[l..r] into slot o.
void build(int& o, int l, int r) {
	if(l > r) { o = 0; return; }
	int mid = (l + r) >> 1;
	o = seq[mid];
	build(ch[o][0], l, mid - 1);
	build(ch[o][1], mid + 1, r);
	if(ch[o][0]) fa[ch[o][0]] = o;
	if(ch[o][1]) fa[ch[o][1]] = o;
	maintain(o);
}

// Rebuilds the subtree in slot o, dropping marked nodes.
void rebuild(int& o) {
	cntn = 0;
	getnode(o);
	build(o, 1, cntn);
}

// Rebuilds the subtree rooted at u, re-links it under u's former parent, and
// refreshes every ancestor (their reas shrinks when marked nodes are dropped).
void rebuildAt(int u) {
	int p = fa[u]; // read before getnode() clears it
	if(!p) {
		rebuild(rt);
		fa[rt] = 0;
		return;
	}
	int& slot = (ch[p][0] == u) ? ch[p][0] : ch[p][1];
	rebuild(slot);
	if(slot) fa[slot] = p;
	for(; p; p = fa[p]) maintain(p);
}

void Insert(int v) {
	rb = 0;
	insert(rt, v);
	if(rb) rebuildAt(rb);
}

// True if fewer than Bili of o's physical nodes are still live.
bool unbal2(int o) {
	return Bili * ns[o].reas > ns[o].siz;
}

// Marks the k-th live value (1-based) as deleted; records the topmost node
// that now has too many marked nodes in rb.
void del(int o, int k) {
	if(!o) return;
	int ls = ch[o][0] ? ns[ch[o][0]].siz : 0;
	int self = ns[o].del ^ 1;
	if(k == ls + 1 && !ns[o].del) ns[o].del = 1;
	else if(k > ls + self) del(ch[o][1], k - ls - self);
	else del(ch[o][0], k);
	maintain(o);
	if(unbal2(o)) rb = o;
}

// Number of live values strictly less than v.
int Find(int o, int v) {
	if(!o) return 0;
	int ls = ch[o][0] ? ns[ch[o][0]].siz : 0;
	if(v <= ns[o].v) return Find(ch[o][0], v);
	return ls + (ns[o].del ^ 1) + Find(ch[o][1], v);
}

// k-th smallest live value (1-based), or 0 if k is out of range.
int qkth(int o, int k) {
	if(!o) return 0;
	int ls = ch[o][0] ? ns[ch[o][0]].siz : 0;
	int self = ns[o].del ^ 1;
	if(k == ls + 1 && !ns[o].del) return ns[o].v;
	if(k > ls + self) return qkth(ch[o][1], k - ls - self);
	return qkth(ch[o][0], k);
}

// Deletes one occurrence of v; does nothing if v is not present.
void Delete(int v) {
	int id = Find(rt, v) + 1;
	int total = rt ? ns[rt].siz : 0;
	// id is the rank v would have; without this check an absent v would
	// delete its successor instead.
	if(id > total || qkth(rt, id) != v) return;
	rb = 0;
	del(rt, id);
	if(rb) rebuildAt(rb);
}

// Largest live value strictly less than v, or -oo if none.
int qlow(int o, int v) {
	if(!o) return -oo;
	if(ns[o].v < v) {
		int here = ns[o].del ? -oo : ns[o].v;
		int left = ch[o][0] ? ns[ch[o][0]].mx : -oo;
		return max(max(here, left), qlow(ch[o][1], v));
	}
	return qlow(ch[o][0], v);
}

// Smallest live value strictly greater than v, or oo if none.
int qupp(int o, int v) {
	if(!o) return oo;
	if(ns[o].v > v) {
		int here = ns[o].del ? oo : ns[o].v;
		int right = ch[o][1] ? ns[ch[o][1]].mn : oo;
		return min(min(here, right), qupp(ch[o][0], v));
	}
	return qupp(ch[o][1], v);
}

int main() {
	int q = read();
	while(q--) {
		int tp = read(), x = read();
		if(tp == 1) Insert(x);
		if(tp == 2) Delete(x);
		if(tp == 3) printf("%d\n", Find(rt, x) + 1);
		if(tp == 4) printf("%d\n", qkth(rt, x));
		if(tp == 5) printf("%d\n", qlow(rt, x));
		if(tp == 6) printf("%d\n", qupp(rt, x));
	}
	return 0;
}
懒惰删除真的一点都不“懒惰”!打了删除标记反而更难处理了……要考虑的细节太多。
再贴一个 fhq treap 的版本。
// Non-rotating (fhq) treap for BZOJ 3224: a multiset kept as a treap ordered
// by value, manipulated with split-by-size and merge.
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <utility>
#include <cassert>
using namespace std;

const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;

// Buffered single-byte reader over stdin. Returns the next byte as an
// unsigned-char value, or EOF once input is exhausted.
inline int Getchar() {
	if(Head == Tail) {
		size_t l = fread(buffer, 1, BufferSize, stdin);
		if(l == 0) return EOF;
		Tail = (Head = buffer) + l;
	}
	return static_cast<unsigned char>(*Head++);
}

// Reads one signed decimal integer; returns 0 if input ends before a digit.
int read() {
	int x = 0, f = 1;
	int c = Getchar();
	while(c != EOF && !isdigit(c)) {
		if(c == '-') f = -1;
		c = Getchar();
	}
	while(isdigit(c)) {
		x = x * 10 + c - '0';
		c = Getchar();
	}
	return x * f;
}

#define maxn 100010
#define oo 2147483647

// (root of left part, root of right part) as produced by split().
// A typedef instead of macros: the old `#define x first` / `#define y second`
// silently rewrote every later identifier named x or y.
typedef pair<int, int> pii;

// v = value, r = random priority (parents have the larger r), siz = subtree size.
struct Node {
	int v, r, siz;
	Node() {}
	explicit Node(int value): v(value), r(rand()), siz(1) {}
} ns[maxn];

// ToT = pool slots used; rt = root (0 = empty); ch[o][0/1] = children.
int ToT, rt, ch[maxn][2];

// Recomputes ns[o].siz from o's children. No-op on the null node 0.
void maintain(int o) {
	if(!o) return;
	ns[o].siz = 1;
	if(ch[o][0]) ns[o].siz += ns[ch[o][0]].siz;
	if(ch[o][1]) ns[o].siz += ns[ch[o][1]].siz;
}

// Merges treaps a and b into one and returns its root.
// Requires every value in a to be <= every value in b.
int merge(int a, int b) {
	if(!a) { maintain(b); return b; }
	if(!b) { maintain(a); return a; }
	if(ns[a].r > ns[b].r) {
		ch[a][1] = merge(ch[a][1], b);
		maintain(a);
		return a;
	}
	ch[b][0] = merge(a, ch[b][0]);
	maintain(b);
	return b;
}

// Splits the treap at o into (first k nodes in order, the rest).
pii split(int o, int k) {
	if(!o) return make_pair(0, 0);
	if(!k) return make_pair(0, o);
	int ls = ch[o][0] ? ns[ch[o][0]].siz : 0;
	if(k <= ls) {
		pii pr = split(ch[o][0], k);
		ch[o][0] = pr.second;
		maintain(o);
		return make_pair(pr.first, o);
	}
	pii pr = split(ch[o][1], k - ls - 1);
	ch[o][1] = pr.first;
	maintain(o);
	return make_pair(o, pr.second);
}

// Number of stored values strictly less than v.
int Find(int o, int v) {
	if(!o) return 0;
	int ls = ch[o][0] ? ns[ch[o][0]].siz : 0;
	if(v <= ns[o].v) return Find(ch[o][0], v);
	return ls + 1 + Find(ch[o][1], v);
}

// Inserts v just before the existing values >= v.
void Insert(int v) {
	int rnk = Find(rt, v);
	pii pr = split(rt, rnk);
	ns[++ToT] = Node(v);
	rt = merge(merge(pr.first, ToT), pr.second);
}

// Removes one occurrence of v from the subtree in slot o; no-op if absent.
void Delete(int& o, int v) {
	if(!o) return;
	if(ns[o].v == v) {
		o = merge(ch[o][0], ch[o][1]);
		return;
	}
	if(v < ns[o].v) Delete(ch[o][0], v);
	else Delete(ch[o][1], v);
	maintain(o);
}

// k-th smallest value (1-based). k must be in [1, size]; in NDEBUG builds an
// out-of-range k yields -oo instead of recursing on the null node forever.
int qkth(int o, int k) {
	assert(o);
	if(!o) return -oo;
	int ls = ch[o][0] ? ns[ch[o][0]].siz : 0;
	if(k == ls + 1) return ns[o].v;
	if(k <= ls) return qkth(ch[o][0], k);
	return qkth(ch[o][1], k - ls - 1);
}

// Largest stored value strictly less than v, or -oo if none.
int qpre(int o, int v) {
	if(!o) return -oo;
	if(ns[o].v < v) return max(ns[o].v, qpre(ch[o][1], v));
	return qpre(ch[o][0], v);
}

// Smallest stored value strictly greater than v, or oo if none.
int qnxt(int o, int v) {
	if(!o) return oo;
	if(ns[o].v > v) return min(ns[o].v, qnxt(ch[o][0], v));
	return qnxt(ch[o][1], v);
}

int main() {
	srand(5); // fixed seed: reproducible priorities
	int q = read();
	while(q--) {
		int tp = read();
		int v = read(); // always consume the operand so input stays aligned
		if(tp == 1) Insert(v);
		if(tp == 2) Delete(rt, v);
		if(tp == 3) printf("%d\n", Find(rt, v) + 1);
		if(tp == 4) printf("%d\n", qkth(rt, v));
		if(tp == 5) printf("%d\n", qpre(rt, v));
		if(tp == 6) printf("%d\n", qnxt(rt, v));
	}
	return 0;
}