算法学习笔记(15): Splay树

Splay树

Splay树又名伸展树, 是tarjan为LCT而发明的平衡树, 通过旋转操作维护二叉搜索树的高度平衡, 其实不管时间复杂度的证明, Splay树挺简单的。 均摊复杂度 \(O(logn)\)(需要用到势能分析), 可以区间操作, 不能可持久化, 常数较大(大于FHQtreap), 但是可以 \(O(nlogn)\) 实现 LCT。(这是唯一比FHQtreap优秀的店...)

算法

splay树通常是把需要操作的点旋转到根, 这样就可以进行 \(O(1)\) 的操作, 同时旋转时要注意不能无脑转, 要兼顾高度平衡。

rotate操作

分为左旋和右旋, 目的是将节点 \(x\) 旋转到父亲 \(f\) 的位置。 (图例搬的OI-wiki)
image

splay操作

这是将某个节点旋转到根节点处的一系列操作, 我们将其统称为splay操作。 也就是说会经历若干个左旋和右旋操作, 从而将 \(x\) 旋转到根。
splay操作需要分类讨论:
此处请见OI-wiki splay分类讨论 (太懒了)
之所以我们要这样分类讨论, 是因为遵循这些规则转的话可以使高度较为平衡, 接近于 \(O(logn)\), 从而达到均摊复杂度 \(O(logn)\)

模板代码

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
struct Splay{
	int rt, tot, fa[N], ch[N][2], val[N], cnt[N], siz[N];
	void maintain(int u) { siz[u] = siz[ch[u][0]] + siz[ch[u][1]] + cnt[u]; }
	bool get(int u) { return u == ch[fa[u]][1]; }
	void clear(int u) { ch[u][0] = ch[u][1] = fa[u] = val[u] = cnt[u] = siz[u] = 0; }
	void rotate(int u) {
		int f = fa[u], gf = fa[f], chk = get(u);
		ch[f][chk] = ch[u][chk ^ 1];
		if (ch[u][chk ^ 1]) fa[ch[u][chk ^ 1]] = f;
		ch[u][chk ^ 1] = f;
		fa[f] = u;
		fa[u] = gf;
		if (gf) ch[gf][f == ch[gf][1]] = u;
		maintain(u);
		maintain(f);
	} 
	void splay(int u) {
		for (int f = fa[u]; f = fa[u], f; rotate(u)) 
			if (fa[f]) rotate(get(u) == get(f) ? f : u);
		rt = u;
	}
	void ins(int k) {
		if (!rt) {
			val[++tot] = k, cnt[tot]++;
			rt = tot;
			maintain(rt);
			return;
		}
		int cur = rt, f = 0;
		while (1) {
			if (val[cur] == k) {
				cnt[cur]++;
				maintain(cur);
				maintain(f);
				splay(cur);
				break;
			}
			f = cur;
			cur = ch[cur][val[cur] < k];
			if (!cur) {
				val[++tot] = k;
				cnt[tot]++;
				fa[tot] = f;
				ch[f][val[f] < k] = tot;
				maintain(tot);
				maintain(f);
				splay(tot);
				break;
			}
		}
	}
	int rk(int k) {
		int res = 0, cur = rt;
		while (1) {
			if (k < val[cur]) cur = ch[cur][0];
			else {
				res += siz[ch[cur][0]];
				if (!cur) return res + 1;
				if (k == val[cur]) {
					splay(cur);
					return res + 1;
				}
				res += cnt[cur];
				cur = ch[cur][1];
			}
		}
	}
	int kth(int k) {
		int cur = rt;
		while (1) {
			if (ch[cur][0] && k <= siz[ch[cur][0]]) cur = ch[cur][0];
			else {
				k -= cnt[cur] + siz[ch[cur][0]];
				if (k <= 0) {
					splay(cur);
					return val[cur];	
				}
				cur = ch[cur][1];
			}
		}
	} 
	int pre() {
		int cur = ch[rt][0];
		if (!cur) return cur;
		while (ch[cur][1]) cur = ch[cur][1];
		splay(cur);
		return cur;
	}
	int suf() {
		int cur = ch[rt][1];
		if (!cur) return cur;
		while (ch[cur][0]) cur = ch[cur][0];
		splay(cur);
		return cur;
	}
	void del(int k) {
		rk(k);
		if (cnt[rt] > 1) {
			cnt[rt]--;
			maintain(rt);
			return;
		}
		if (!ch[rt][0] && !ch[rt][1]) {
			clear(rt);
			rt = 0;
			return;
		}
		if (!ch[rt][0]) {
			int cur = rt;
			rt = ch[rt][1];
			fa[rt] = 0;
			clear(cur);
			return;
		}
		if (!ch[rt][1]) {
			int cur = rt;
			rt = ch[rt][0];
			fa[rt] = 0;
			clear(cur);
			return;
		}
		int cur = rt;
		int x = pre();
		fa[ch[cur][1]] = x;
		ch[x][1] = ch[cur][1];
		clear(cur);
		maintain(rt);
	}
}T;
int n;
int main() {
	scanf("%d", &n);
	for (int i = 1, op, x; i <= n; i++) {
		scanf("%d%d", &op, &x);
		switch(op) {
			case 1: T.ins(x); break;
			case 2: T.del(x); break;
			case 3: printf("%d\n", T.rk(x)); break;
			case 4: printf("%d\n", T.kth(x)); break;
			case 5: T.ins(x); printf("%d\n", T.val[T.pre()]); T.del(x); break;
			case 6: T.ins(x); printf("%d\n", T.val[T.suf()]); T.del(x); break;
		} 
	}
	return 0;
}

区间操作

区间翻转

可以打一个懒标记, 至于怎么打, 就是把代表区间 \((l, r)\) 的子树提出来, 然后打标记。 类似FHQ的, 我们需要按照数组下标为BST的键值建树。
方法: 将 \(l - 1\) 旋转到根, 将 \(r + 1\) 旋转到 \(l - 1\) 的右儿子, 这样区间 \((l, r)\) 就都在 \(r + 1\) 的左子树里, 直接打标记就可以了。

模板代码

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
const int INF = 1e9;
int a[N];
struct Splay{
	int tot, rt, cnt[N], siz[N], val[N], ch[N][2], fa[N], tag[N];
	void maintain(int u) { siz[u] = siz[ch[u][0]] + siz[ch[u][1]] + cnt[u]; }
	void clear(int u) { cnt[u] = siz[u] = val[u] = ch[u][0] = ch[u][1] = fa[u] = tag[u] = 0; }
	bool get(int u) { return u == ch[fa[u]][1]; }
	void pushdown(int u) {
		if (u && tag[u]) {
			tag[ch[u][0]] ^= 1;
			tag[ch[u][1]] ^= 1;
			swap(ch[u][0], ch[u][1]);
			tag[u] = 0; 
		}
	}  
	void build(int &p, int l, int r, int pre) {
		if (l > r) return;
		p = ++tot; int mid = l + r >> 1;
		cnt[tot]++, val[tot] = a[mid], fa[tot] = pre;
		build(ch[p][0], l, mid - 1, p);
		build(ch[p][1], mid + 1, r, p);
		maintain(p);
	}
	void rotate(int u) {
		int f = fa[u], gf = fa[f], chk = get(u);
		pushdown(f); pushdown(u);
		ch[f][chk] = ch[u][chk ^ 1];
		if (ch[u][chk ^ 1]) fa[ch[u][chk ^ 1]] = f;
		ch[u][chk ^ 1] = f;
		fa[f] = u;
		fa[u] = gf;
		if (gf) ch[gf][f == ch[gf][1]] = u;
		maintain(u);
		maintain(f);
	}
	void splay(int u, int goal) {
		for (int f = fa[u]; (f = fa[u]) != goal; rotate(u)) 
			if (fa[f] != goal) rotate(get(f) == get(u) ? f : u);
		if (!goal) rt = u;
	}
	int kth(int k) {
		int cur = rt;
		while(1) {
			pushdown(cur);
			if (ch[cur][0] && k <= siz[ch[cur][0]]) cur = ch[cur][0];
			else  {
				k -= siz[ch[cur][0]] + cnt[cur];
				if (k <= 0) return cur;
				cur = ch[cur][1];
			}
		}
	}
	void reverse(int l, int r) {
		l = l - 1, r = r + 1;
		l = kth(l), r = kth(r);
		splay(l, 0); splay(r, l);
		tag[ch[r][0]] ^= 1;
	}
	void print(int u) {
		pushdown(u);
		if (ch[u][0]) print(ch[u][0]);
		if (val[u] != -INF && val[u] != INF) printf("%d ", val[u]);
		if (ch[u][1]) print(ch[u][1]);
	}
}T;
int n, m;
int main() {
	scanf("%d%d", &n, &m);
	a[1] = -INF, a[n + 2] = INF;
	for (int i = 2; i <= n + 1; i++) a[i] = i - 1;
	T.build(T.rt, 1, n + 2, 0); 
	for (int i = 1, l, r; i <= m; i++) {
		scanf("%d%d", &l, &r);
		l++, r++;
		T.reverse(l, r);
	}
	T.print(T.rt);
	return 0;
}
posted @ 2024-05-06 10:55  qqrj  阅读(52)  评论(0)    收藏  举报