洛谷U644824 简单的平衡树问题 题解 splay tree 模板题
题目链接:https://www.luogu.com.cn/problem/U644824
示例程序:
#include <bits/stdc++.h>
using namespace std;
// Capacity of the statically allocated arrays below (the node pool tr[] and the
// value array a[]). Every ins() takes a fresh slot via ++idx and slots are never
// reused, so this must exceed the total number of insertions performed.
const int maxn = 4e5 + 5;
// One tree node, addressed by its index in the static pool `tr`.
// Index 0 plays the role of the null node.
struct Node {
    int s[2];  // child indices: s[0] = left, s[1] = right
    int p;     // parent index
    int v;     // key stored in this node
    int sz;    // number of nodes in the subtree rooted here

    // Reset this slot to a childless node holding key `_v` with parent `_p`.
    void init(int _v, int _p) {
        v = _v;
        p = _p;
        sz = 1;
        s[0] = 0;
        s[1] = 0;
    }
} tr[maxn];
int root;     // index of the current root node; 0 while the tree is empty
int idx;      // index of the most recently allocated node (new nodes use ++idx)
int n;        // number of initial elements, read in main()
int Q;        // number of operations, read in main()
int a[maxn];  // a[1..n]: the value currently stored at each position
// Recompute the subtree size of node u from the sizes of its two children.
void push_up(int u) {
    tr[u].sz = 1 + tr[tr[u].s[0]].sz + tr[tr[u].s[1]].sz;
}
void f_s(int p, int u, int k) {
tr[p].s[k] = u;
tr[u].p = p;
}
// Rotate x one level up, above its current parent, and refresh the sizes of
// the two nodes whose subtrees changed.
void rot(int x) {
    const int par = tr[x].p;
    const int grand = tr[par].p;
    const int side = (tr[par].s[1] == x) ? 1 : 0;         // which child x is of par
    const int grand_side = (tr[grand].s[1] == par) ? 1 : 0;  // which child par is of grand
    const int inner = tr[x].s[side ^ 1];                   // subtree that changes owner

    f_s(grand, x, grand_side);   // x takes par's place under grand
    f_s(par, inner, side);       // par adopts x's inner subtree
    f_s(x, par, side ^ 1);       // par becomes a child of x

    // par is now below x, so its size has to be recomputed first.
    push_up(par);
    push_up(x);
}
// Rotate x upward until its parent is k. When k is 0, x ends up as the root
// and `root` is updated accordingly.
void splay(int x, int k) {
    for (int y = tr[x].p; y != k; y = tr[x].p) {
        const int z = tr[y].p;
        if (z != k) {
            const bool x_is_right = (tr[y].s[1] == x);
            const bool y_is_right = (tr[z].s[1] == y);
            if (x_is_right != y_is_right) {
                rot(x);  // x and y hang on opposite sides: rotate x twice
            } else {
                rot(y);  // same side: rotate the parent first
            }
        }
        rot(x);
    }
    if (k == 0) root = x;
}
// Return the k-th smallest key (1-based), or -1 when no node has that rank.
// The tree is only read here: the visited node is not splayed.
int get_k(int k) {
    for (int cur = root; cur != 0;) {
        const int left_size = tr[tr[cur].s[0]].sz;
        if (k <= left_size) {
            cur = tr[cur].s[0];
            continue;
        }
        if (k == left_size + 1) return tr[cur].v;
        // Skip the left subtree and the current node, then continue on the right.
        k -= left_size + 1;
        cur = tr[cur].s[1];
    }
    return -1;
}
int get_rnk(int x) {
int cnt = 1, u = root;
while (u) {
if (tr[u].v < x) {
cnt += 1 + tr[ tr[u].s[0] ].sz;
u = tr[u].s[1];
}
else
u = tr[u].s[0];
}
return cnt;
}
// Insert key v as a new node (duplicates get their own nodes) and splay the
// new node to the root.
void ins(int v) {
    int parent = 0;
    int side = 0;
    for (int cur = root; cur != 0; cur = tr[cur].s[side]) {
        ++tr[cur].sz;                     // the new node will live in this subtree
        side = (tr[cur].v < v) ? 1 : 0;   // keys equal to the node's key go left
        parent = cur;
    }
    const int fresh = ++idx;
    tr[fresh].init(v, parent);
    if (parent != 0) tr[parent].s[side] = fresh;
    splay(fresh, 0);
}
// Remove one node whose key equals v. Does nothing when v is not stored.
void del(int v) {
    int u = root;
    while (u && tr[u].v != v) u = tr[u].s[tr[u].v < v];
    // Key not present (or empty tree): splaying the null node 0 would set
    // `root` to 0 and drop the whole tree, so stop before touching anything.
    if (!u) return;
    splay(u, 0);
    int l = tr[u].s[0], r = tr[u].s[1];
    if (!l || !r) {
        // At most one child: that child (or 0 if there is none) becomes the root.
        root = l + r;
        tr[root].p = 0;
    } else {
        // Two children: locate the in-order neighbours of u.
        while (tr[l].s[1]) l = tr[l].s[1];
        while (tr[r].s[0]) r = tr[r].s[0];
        // With l at the root and r directly below it, u is the only node in
        // r's left subtree and can be cut off.
        splay(l, 0);
        splay(r, l);
        tr[r].s[0] = 0;
        push_up(r);
        push_up(l);
    }
}
// Return the largest stored key strictly below x, or -1 when there is none.
// The tree is only read here: the visited node is not splayed.
int get_pre(int x) {
    int best = -1;
    for (int cur = root; cur != 0;) {
        const bool below = tr[cur].v < x;
        if (below) best = tr[cur].v;  // candidate; anything better lies to the right
        cur = tr[cur].s[below ? 1 : 0];
    }
    return best;
}
// Return the smallest stored key strictly above x, or -1 when there is none.
// The tree is only read here: the visited node is not splayed.
int get_suc(int x) {
    int best = -1;
    for (int cur = root; cur != 0;) {
        const bool above = tr[cur].v > x;
        if (above) best = tr[cur].v;  // candidate; anything better lies to the left
        cur = tr[cur].s[above ? 0 : 1];
    }
    return best;
}
// Read n and Q, insert the n initial values, then process Q operations:
//   1 p x : replace the value at position p with x
//   2 x   : print the rank of x
//   3 k   : print the k-th smallest value
//   4 x   : print the predecessor of x
//   other : print the successor of x (operation 5)
int main() {
    scanf("%d%d", &n, &Q);
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
        ins(a[i]);
    }
    while (Q--) {
        int op = 0;
        scanf("%d", &op);
        switch (op) {
            case 1: {  // 1 p x
                int p = 0, x = 0;
                scanf("%d%d", &p, &x);
                del(a[p]);
                a[p] = x;
                ins(a[p]);
                break;
            }
            case 2: {  // 2 x
                int x = 0;
                scanf("%d", &x);
                printf("%d\n", get_rnk(x));
                break;
            }
            case 3: {  // 3 k
                int k = 0;
                scanf("%d", &k);
                printf("%d\n", get_k(k));
                break;
            }
            case 4: {  // 4 x
                int x = 0;
                scanf("%d", &x);
                printf("%d\n", get_pre(x));
                break;
            }
            default: {  // 5 x
                int x = 0;
                scanf("%d", &x);
                printf("%d\n", get_suc(x));
                break;
            }
        }
    }
    return 0;
}
浙公网安备 33010602011771号