[BZOJ3196][Tyvj1730]二逼平衡树

[BZOJ3196][Tyvj1730]二逼平衡树

试题描述

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

  1. 查询 \(k\) 在区间内的排名

  2. 查询区间内排名为 \(k\) 的值

  3. 修改某一位值上的数值

  4. 查询 \(k\) 在区间内的前驱(前驱定义为小于 \(x\),且最大的数)

  5. 查询 \(k\) 在区间内的后继(后继定义为大于 \(x\),且最小的数)

输入

第一行两个数 \(n,m\) 表示长度为 \(n\) 的有序序列和 \(m\) 个操作

第二行有 \(n\) 个数,表示有序序列

下面有 \(m\) 行,\(opt\) 表示操作标号

\(opt=1\) 则为操作 \(1\),之后有三个数 \(l,r,k\) 表示查询 \(k\) 在区间 \([l,r]\) 的排名

\(opt=2\) 则为操作 \(2\),之后有三个数 \(l,r,k\) 表示查询区间 \([l,r]\) 内排名为 \(k\) 的数

\(opt=3\) 则为操作 \(3\),之后有两个数 \(pos,k\) 表示将 \(pos\) 位置的数修改为 \(k\)

\(opt=4\) 则为操作 \(4\),之后有三个数 \(l,r,k\) 表示查询区间 \([l,r]\)\(k\) 的前驱

\(opt=5\) 则为操作 \(5\),之后有三个数 \(l,r,k\) 表示查询区间 \([l,r]\)\(k\) 的后继

输出

对于操作 \(1,2,4,5\) 各输出一行,表示查询结果

输入示例

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

输出示例

2
4
3
4
9

数据规模及约定

  1. \(n\)\(m\) 的数据范围:\(n,m \le 50000\)

  2. 序列中每个数的数据范围:\([0,10^8]\)

  3. 虽然原题没有,但事实上 \(5\) 操作的 \(k\) 可能为负数

题解

填个坑学了学 fhq treap,感觉挺好写的,还可以轻易地可持久化。就用这个树套树裸题练练手。(下面代码是 \(O(n \log^3 n)\) 的)

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define rep(i, s, t) for(int i = (s), mi = (t); i <= mi; i++)
#define dwn(i, s, t) for(int i = (s), mi = (t); i >= mi; i--)

int read() {
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
	return x * f;
}

#define maxn 50010
#define maxnode 2000010
#define oo 2147483647
#define pii pair <int, int>
#define x first
#define y second
#define mp(x, y) make_pair(x, y)

int n, A[maxn];

struct Node {
	int v, r, siz;
	Node() {}
	Node(int _): v(_), r(rand()), siz(1) {}
} ns[maxnode];
int ToT, rt[maxn<<2], ch[maxnode][2];

void maintain(int o) {
	if(!o) return ;
	ns[o].siz = 1;
	if(ch[o][0]) ns[o].siz += ns[ch[o][0]].siz;
	if(ch[o][1]) ns[o].siz += ns[ch[o][1]].siz;
	return ;
}
int merge(int a, int b) {
	if(!a) return maintain(b), b;
	if(!b) return maintain(a), a;
	if(ns[a].r > ns[b].r) return ch[a][1] = merge(ch[a][1], b), maintain(a), a;
	return ch[b][0] = merge(a, ch[b][0]), maintain(b), b;
}
pii split(int o, int v) {
	if(!o) return mp(0, 0);
	pii pr;
	if(v <= ns[o].v) {
		pr = split(ch[o][0], v);
		ch[o][0] = pr.y; maintain(o);
		return mp(pr.x, o);
	}
	pr = split(ch[o][1], v);
	ch[o][1] = pr.x; maintain(o);
	return mp(o, pr.y);
}
void Insert(int& o, int v) {
	pii pr = split(o, v);
	ns[++ToT] = Node(v);
	o = merge(pr.x, ToT);
	o = merge(o, pr.y);
	return ;
}
void Delete(int& o, int v) {
	if(!o) return ;
	if(v == ns[o].v) return (void)(o = merge(ch[o][0], ch[o][1]));
	if(v < ns[o].v) Delete(ch[o][0], v);
	else Delete(ch[o][1], v);
	return maintain(o);
}
int qrnk(int o, int v) {
	if(!o) return 0;
	int ls = ch[o][0] ? ns[ch[o][0]].siz : 0;
	if(v <= ns[o].v) return qrnk(ch[o][0], v);
	return ls + 1 + qrnk(ch[o][1], v);
}
int qpre(int o, int v) {
	if(!o) return -1;
	if(ns[o].v < v) return max(ns[o].v, qpre(ch[o][1], v));
	return qpre(ch[o][0], v);
}
int qnxt(int o, int v) {
	if(!o) return oo;
	if(ns[o].v > v) return min(ns[o].v, qnxt(ch[o][0], v));
	return qnxt(ch[o][1], v);
}

void build(int o, int l, int r) {
	rep(i, l, r) Insert(rt[o], A[i]);
	if(l == r) return ;
	int mid = l + r >> 1, lc = o << 1, rc = lc | 1;
	build(lc, l, mid); build(rc, mid + 1, r);
	return ;
}
int modify(int o, int l, int r, int p, int v) {
	if(l == r) {
		int tmp = ns[rt[o]].v;
		Delete(rt[o], tmp);
		Insert(rt[o], v);
		return tmp;
	}
	int mid = l + r >> 1, lc = o << 1, rc = lc | 1, tmp;
	if(p <= mid) Delete(rt[o], tmp = modify(lc, l, mid, p, v)), Insert(rt[o], v);
	else Delete(rt[o], tmp = modify(rc, mid + 1, r, p, v)), Insert(rt[o], v);
	return tmp;
}
int qsmaller(int o, int l, int r, int ql, int qr, int v) {
	if(ql <= l && r <= qr) return qrnk(rt[o], v);
	int mid = l + r >> 1, lc = o << 1, rc = lc | 1, ans = 0;
	if(ql <= mid) ans += qsmaller(lc, l, mid, ql, qr, v);
	if(qr > mid) ans += qsmaller(rc, mid + 1, r, ql, qr, v);
	return ans;
}
int qkth(int ql, int qr, int k) {
	int l = 0, r = (int)1e8 + 1;
	while(r - l > 1) {
		int mid = l + r >> 1;
		if(qsmaller(1, 1, n, ql, qr, mid) + 1 <= k) l = mid; else r = mid;
	}
	return l;
}
int askpre(int o, int l, int r, int ql, int qr, int v) {
	if(ql <= l && r <= qr) return qpre(rt[o], v);
	int mid = l + r >> 1, lc = o << 1, rc = lc | 1, ans = -1;
	if(ql <= mid) ans = max(ans, askpre(lc, l, mid, ql, qr, v));
	if(qr > mid) ans = max(ans, askpre(rc, mid + 1, r, ql, qr, v));
	return ans;
}
int asknxt(int o, int l, int r, int ql, int qr, int v) {
	if(ql <= l && r <= qr) return qnxt(rt[o], v);
	int mid = l + r >> 1, lc = o << 1, rc = lc | 1, ans = oo;
	if(ql <= mid) ans = min(ans, asknxt(lc, l, mid, ql, qr, v));
	if(qr > mid) ans = min(ans, asknxt(rc, mid + 1, r, ql, qr, v));
	return ans;
}

int main() {
	n = read(); int q = read();
	rep(i, 1, n) A[i] = read();
	
	build(1, 1, n);
	while(q--) {
		int tp = read(), l, r, k;
		if(tp == 1) l = read(), r = read(), k = read(), printf("%d\n", qsmaller(1, 1, n, l, r, k) + 1);
		if(tp == 2) l = read(), r = read(), k = read(), printf("%d\n", qkth(l, r, k));
		if(tp == 3) l = read(), k = read(), modify(1, 1, n, l, k);
		if(tp == 4) l = read(), r = read(), k = read(), printf("%d\n", askpre(1, 1, n, l, r, k));
		if(tp == 5) l = read(), r = read(), k = read(), printf("%d\n", asknxt(1, 1, n, l, r, k));
	}
	
	return 0;
}
posted @ 2018-03-04 21:29  xjr01  阅读(275)  评论(0编辑  收藏  举报