Splay学习笔记

问题分析:
(来源:洛谷P3369【模板】普通平衡树)

您需要动态地维护一个可重集合 \(M\),并且提供以下操作:

  1. \(M\) 中插入一个数 \(x\)
  2. \(M\) 中删除一个数 \(x\)(若有多个相同的数,应只删除一个)。
  3. 查询 \(M\) 中有多少个数比 \(x\) 小,并且将得到的答案加一。
  4. 查询如果将 \(M\) 从小到大排列后,排名位于第 \(x\) 位的数。
  5. 查询 \(M\)\(x\) 的前驱(前驱定义为小于 \(x\),且最大的数)。
  6. 查询 \(M\)\(x\) 的后继(后继定义为大于 \(x\),且最小的数)。

平衡二叉树:

平衡二叉树满足这样的性质:对于每个节点x,其左子树的树高与右子树的树高相差不超过1。(实际应用中,往往不需要严格保证左右子树高差不超过1)
要实现这个操作,需要在构建二叉树的同时进行特殊操作,保证既不破坏二叉搜索树的性质,又能对树的形态进行调整。
实现这一方法的操作称为旋转,分为左旋和右旋。
以左旋为例,若根节点为x,左儿子为y,则需要在左旋后将根节点变为y,x成为y的右儿子。
旋转的流程为:
1.将y的右儿子接在x的左子树的位置
2.将y接在x的右子树的位置
3.将节点x接在y的父节点的对应位置
右旋同理

Splay:

Splay树,又称伸展树,是一种实现平衡树的方法,它在每次操作后,强制将节点旋至根节点。该操作称为伸展。伸展操作可以保证树的高度为O(logn)的级别。
伸展的实现依靠旋转,基于当前节点x,对x的父节点y,y的父节点z进行操作,分为以下三种

1.zip

此操作仅用于z不存在的情况,即y为根节点,此时对x进行旋转操作。

2.zip-zip

此操作用于x相对y,y相对z在同侧,即都为左儿子或右儿子。此时先对y进行旋转操作,再对x进行旋转操作。

3.zip-zap

此操作用于x相对y,y相对z在异侧。此时先对x进行旋转操作,再对y进行旋转操作。
通过这种方式能在不破坏二叉树性质的前提下将当前节点移至根节点。

代码分析:
旋转操作

按前述即可,取x,x的父节点y,y的父节点z操作。注意判断节点的存在性。

void rot(int x){
	int y = fa[x], z = fa[y];
	int d = dir(x);//判断是父节点的左or右孩子
	ch[y][d] = ch[x][d ^ 1];
	ch[x][d ^ 1] = y;
	if(z) ch[z][dir(y)] = x;
	if(ch[y][d]) fa[ch[y][d]] = y;
	fa[y] = x;
	fa[x] = z;
	pushup(y);//注意先更新子节点,再更新父节点
	pushup(x);
}
伸展操作:

按照前述实现即可,类似冒泡做法,使用循环不断向根节点更新。

void splay(int x){
	int y = fa[x];
	while(y){//上浮至根节点
		if(fa[y]) rot(dir(x) == dir(y) ? y : x);//区分三种旋转操作
		rot(x);
		y = fa[x];
	}
	root = x;//修改记录值

}

查找指定节点的位置:

依据二叉搜索树的性质即可,注意最后的伸展操作。

void find(int v){
	int x = root, y = 0;
	while(x && val[x] != v){//直接迭代即可
		y = x;
		x = ch[x][v > val[x]];
	}
	splay(x ? x : y);//注意如果x没找到就返回父节点
}
插入操作:

先利用二叉搜索树的性质找到节点应该插入的位置,再插入。注意最后的伸展操作。

void insert(int v){
	int x = root, y = 0;
	while(x && val[x] != v){
		y = x;
		x = ch[x][v > val[x]];
	}
	if(x){//x已存在
		cnt[x]++;
		siz[x]++;
	}
	else{//x不存在,添加一个x
		x = add(y, v);
	}
	splay(x);
}
查找某个排名的元素的位置:

利用二叉查找树的性质即可,按照子树大小和当前节点计数分类讨论。注意最后的伸展操作。

void loc(int v){
	int x = root;
	while(1){
		if(v <= siz[ch[x][0]]){//在左子树
			x = ch[x][0];
		}
		else if(v <= siz[ch[x][0]] + cnt[x]){//在当前节点范围内
			break;
		}
		else{//在右子树
			v -= siz[ch[x][0]] + cnt[x];
			x = ch[x][1];
		}
	}
	splay(x);
}
合并操作:

要求左子树的最大值小于右子树的最小值。将右子树的最小值作为根节点,把左子树接在根节点的左侧即可。

void merge(int x, int y){
	if(!x || !y){//如果有空子树就不用合并了
		root = x | y;
		return ;
	}
	root = y;
	loc(1);//找右子树的最小值设为根节点
	ch[root][0] = x;//将左子树接上
	fa[x] = root;
	pushup(root);
}
删除操作:

先找到要删除的值所在节点,由于伸展操作,该节点自动成为根节点。若计数为1则利用合并操作将其删除即可。

void remove(int v){
	find(v);
	if(!root || val[root] != v) return ;
	cnt[root]--;
	siz[root]--;
	if(!cnt[root]){//这个节点已经不存在了,由于已经是根节点,直接重新合并左右子树即可
		int x = ch[root][0];
		int y = ch[root][1];
		fa[x] = fa[y] = 0;
		merge(x, y);//合并左右子树
	}
}
查询值的排名:

利用前面的loc函数即可。

int getval(int v){
	loc(v);
	return val[root];
}
查询排名的值:

利用前面的find函数即可,注意值不存在的情况特殊处理。

int getrank(int v){
	find(v);
	return siz[ch[root][0]] + (v <= val[root] ? 0 : cnt[root]) + 1;
	//值可能不存在,如果这个值比当前节点大则需要再加上当前节点的数量
}
查询前驱后继:

先利用find函数找到最近的值,若比查询值小则为答案,否则找到其左子树最右侧的值即可。

int getpre(int v){
	find(v);
	if(root && val[root] < v) return val[root];
	int x = ch[root][0];
	if(!x) return -1;
	while(ch[x][1]) x = ch[x][1];//找最右侧的值,相当于loc
	splay(x);
	return val[root];
}

查询后继同理:

int getnext(int v){
	find(v);
	if(root && v < val[root]) return val[root];
	int x = ch[root][1];
	if(!x) return -1;
	while(ch[x][0]) x = ch[x][0];//找最左侧的值
	splay(x);
	return val[root];
}

完整代码如下:

#include <bits/stdc++.h>
#define MAX 100005
using namespace std;
int n;
int root;
int fa[MAX];
int ch[MAX][2];
int val[MAX];
int cnt[MAX];
int siz[MAX];
int tot;
void pushup(int x){
	siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + cnt[x];
}
int add(int y, int v){
	tot++;
	val[tot] = v;
	siz[tot] = cnt[tot] = 1;
	fa[tot] = y;
	if(y) ch[y][v > val[y]] = tot;
	return tot;
}
int dir(int x){
	return x == ch[fa[x]][1];
}
void rot(int x){
	int y = fa[x], z = fa[y];
	int d = dir(x);
	ch[y][d] = ch[x][d ^ 1];
	ch[x][d ^ 1] = y;
	if(z) ch[z][dir(y)] = x;
	if(ch[y][d]) fa[ch[y][d]] = y;
	fa[y] = x;
	fa[x] = z;
	pushup(y);
	pushup(x);
}
void splay(int x){
	int y = fa[x];
	while(y){
		if(fa[y]) rot(dir(x) == dir(y) ? y : x);
		rot(x);
		y = fa[x];
	}
	root = x;
}
void find(int v){
	int x = root, y = 0;
	while(x && val[x] != v){
		y = x;
		x = ch[x][v > val[x]];
	}
	splay(x ? x : y);
}
void insert(int v){
	int x = root, y = 0;
	while(x && val[x] != v){
		y = x;
		x = ch[x][v > val[x]];
	}
	if(x){
		cnt[x]++;
		siz[x]++;
	}
	else{
		x = add(y, v);
	}
	splay(x);
}
void loc(int v){
	int x = root;
	while(1){
		if(v <= siz[ch[x][0]]){
			x = ch[x][0];
		}
		else if(v <= siz[ch[x][0]] + cnt[x]){
			break;
		}
		else{
			v -= siz[ch[x][0]] + cnt[x];
			x = ch[x][1];
		}
	}
	splay(x);
}
void merge(int x, int y){
	if(!x || !y){
		root = x | y;
		return ;
	}
	root = y;
	loc(1);
	ch[root][0] = x;
	fa[x] = root;
	pushup(root);
}
void remove(int v){
	find(v);
	if(!root || val[root] != v) return ;
	cnt[root]--;
	siz[root]--;
	if(!cnt[root]){
		int x = ch[root][0];
		int y = ch[root][1];
		fa[x] = fa[y] = 0;
		merge(x, y);
	}
}
int getrank(int v){
	find(v);
	return siz[ch[root][0]] + (v <= val[root] ? 0 : cnt[root]) + 1;
}
int getval(int v){
	loc(v);
	return val[root];
}
int getpre(int v){
	find(v);
	if(root && val[root] < v) return val[root];
	int x = ch[root][0];
	if(!x) return -1;
	while(ch[x][1]) x = ch[x][1];
	splay(x);
	return val[root];
}
int getnext(int v){
	find(v);
	if(root && v < val[root]) return val[root];
	int x = ch[root][1];
	if(!x) return -1;
	while(ch[x][0]) x = ch[x][0];
	splay(x);
	return val[root];
}
int main(){
	scanf("%d", &n);
	for(int i = 1; i <= n; i++){
		int ty, x;
		scanf("%d%d", &ty, &x);
		if(ty == 1){
			insert(x);
		}
		else if(ty == 2){
			remove(x);
		}
		else if(ty == 3){
			printf("%d\n", getrank(x));
		}
		else if(ty == 4){
			printf("%d\n", getval(x));
		}
		else if(ty == 5){
			printf("%d\n", getpre(x));
		}
		else{
			printf("%d\n", getnext(x));
		}
	}
	return 0;
}

参考资料:oi-wiki

posted @ 2025-10-11 12:50  LIGHTB  阅读(16)  评论(0)    收藏  举报