Splay随手记
Splay(伸展树)
注意:Splay只有初始的时候满足BST的结构,翻转之后即不满足,但中序遍历依旧不变。
1.旋转方式:同Treap
2.核心:每操作一个节点,均将该节点旋转到树根
意义:每次操作平均时间复杂度是 \(O(logn)\) 的。(可严格证明:可能用势能分析?)
3.核心操作实现(Splay函数)
Splay(x,k):将点x旋转到点k下面
①当其形成一条直线的时候
先转y,再转x(把y向上转,然后把x向上转)
每转一次,x会向上走两格
②当其存在拐点的时候
转x,再转x(就是把x转两次)
只有这样转才能保证 \(O(log n)\) 的复杂度!!!不可以瞎转!!!
写的时候就是一个迭代(递归),不停向上转即可。
4.其余操作
①插入
一般情况:把x插到叶节点,然后选择到根即可
特殊:将一个序列插到y的后面(保证中序遍历在y的后面): 找到y的后继z,将y转到根(Splay(y,0)),将z转到y的下面(Splay(z,y))-> z此刻左子树为空。直接把序列插到z的左子树即可。
②删除
删除一段:删除[l,r],找到前驱l-1,后继r+1。先将l转到根节点(Splay(l-1,0)),再把r+1转到l-1下面(Splay(r+1,l-1)),此时r+1的左子树即为所要删除的段。直接置为空即可。
③其他操作
与其他平衡树无区别
5.Splay如何去维护信息
找第k个数->维护size(每棵子树大小)
翻转区间->维护lazytag(近似于线段树的懒标记)
操作:
pushup():维护值:放到旋转函数的最后
root->size=root->right->size+root->left->size+1
pushdown():下传懒标记:所有递归之前都要下传
swap(root->left,root->right)
下传标记
清空当前标记
6.模版题:文艺平衡树
[link](P3391 【模板】文艺平衡树 - 洛谷)
struct Splay{
struct NODE{
int s[2];//左右二字
int p;//父节点
int v;//编号
int flag;//懒标记:有没有翻转
int size;//大小
void init(int _v,int _p)
{
v=_v,p=_p;
size=1;
}
}tr[N];
int root=0,idx=0;
//根节点 动态分配空间
void pushup(int x)
{
tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+1;
}//维护值
void pushdown(int x)
{
if(tr[x].flag)
{
swap(tr[x].s[0],tr[x].s[1]);//交换左右儿子
tr[tr[x].s[0]].flag^=1;
tr[tr[x].s[1]].flag^=1;//下传标记
tr[x].flag=0;//当前标记清空
}
}//下传标记
void rotate(int x)
{
int y=tr[x].p,z=tr[y].p;//找到祖先
int k=tr[y].s[1]==x;//k表示x是y的左儿子还是右儿子:k=0->x是左儿子
tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
pushup(y),pushup(x);
//先更新y,再更新x:因为y是x的儿子
//z的信息不用变
}//把左右旋合在一起写
void splay(int x,int k)//核心函数
{
while(tr[x].p!=k)
{
int y=tr[x].p,z=tr[y].p;
if(z!=k)
if((tr[y].s[1]==x)^(tr[z].s[1]==y))
rotate(x);//折线关系
else rotate(y);//直线关系
rotate(x);
// cout<<tr[x].p<<" "<<k<<" ";
}
if(!k) root=x;//如果是的话更新根
}//保证了Splay的时间复杂度
void insert(int v)
{
int u=root,p=0;
while(u) p=u,u=tr[u].s[v>tr[u].v];//快速判断在左/右儿子
u=++idx;
if(p) tr[p].s[v>tr[p].v]=u;//如果有父亲,更新
tr[u].init(v,p);//初始化新点
splay(u,0);//为了保证log(n)的时间复杂度,必须转到根
}//插入
int get_k(int k)
{
int u=root;
while(true)
{
pushdown(u);
if(tr[tr[u].s[0]].size>=k) u=tr[u].s[0];//在左子树里找
else if(tr[tr[u].s[0]].size+1==k) return u;//这里就是答案
else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];//在右子树里面找
}
return -1;//找不到返回-1
}//找第k个数
void output(int u)
{
pushdown(u);
if(tr[u].s[0]) output(tr[u].s[0]);//先输出左儿子
if(tr[u].v>=1 && tr[u].v<=n) cout<<tr[u].v<<" ";//输出,判断是不是哨兵
if(tr[u].s[1]) output(tr[u].s[1]);
}//输出中序遍历
}Tr;
7.简单应用
①:郁闷的出纳员
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+100,INF=1e9;
int n,m;
struct Spaly{
struct NODE{
int s[2],p,v;
int size;
void init(int _v,int _p)
{
v=_v,p=_p;
size=1;
}
}tr[N];
int root=0,idx=0,delta=0;
void pushup(int x)
{
tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+1;
}
void rotate(int x)
{
int y=tr[x].p,z=tr[y].p;
int k=tr[y].s[1]==x;
tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
pushup(y),pushup(x);
}
void splay(int x,int k)
{
while(tr[x].p!=k)
{
int y=tr[x].p,z=tr[y].p;
if(z!=k)
if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
else rotate(y);
rotate(x);
}
if(!k) root=x;
}
int insert(int v)
{
int u=root,p=0;
while(u) p=u,u=tr[u].s[v>tr[u].v];
u=++idx;
if(p) tr[p].s[v>tr[p].v]=u;
tr[u].init(v,p);
splay(u,0);
return u;
}
int get(int v)
{
int u=root,res;
while(u)
{
if(tr[u].v>=v) res=u,u=tr[u].s[0];
else u=tr[u].s[1];
}
return res;
}
int get_k(int k)
{
int u=root;
while(u)
{
if(tr[tr[u].s[0]].size>=k) u=tr[u].s[0];
else if(tr[tr[u].s[0]].size+1==k) return tr[u].v;
else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
}
return -1;
}
}Tr;
signed main()
{
cin>>n>>m;
int L=Tr.insert(-INF),R=Tr.insert(INF);//建两个哨兵
int tot=0;
while(n--)
{
char op;
int k;
cin>>op>>k;
if(op=='I')
if(k>=m) k-=Tr.delta,Tr.insert(k),tot++;
if(op=='A') Tr.delta+=k;
if(op=='S')
{
Tr.delta-=k;
R=Tr.get(m-Tr.delta);
Tr.splay(R,0),Tr.splay(L,R);
Tr.tr[L].s[1]=0;
Tr.pushup(L),Tr.pushup(R);
}
if(op=='F')
{
if(Tr.tr[Tr.root].size-2<k) cout<<-1<<endl;
else cout<<Tr.get_k(Tr.tr[Tr.root].size-k)+Tr.delta<<endl;
}
}
cout<<tot-(Tr.tr[Tr.root].size-2)<<endl;
return 0;
}
②:永无乡
本题涉及到了Splay的合并,强调Splay不可以直接做合并,要用启发式合并,用Splay维护集合大小,合并集合。(启发式合并理论:把节点个数少的合并到节点个数多的,时间复杂度为 \(O(nlog n)\) )
关于合并暴力即可。每合并一个元素的时间复杂度为 \(O(log n)\) (插入),启发式合并的时间复杂度是 \(O(nlog n)\) ,总的操作时间复杂度就是 \(O(n log ^2 n)\)。
思路很简单,代码很难调。。。
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5e5+100;
//原本的大小是1e5,但是有插入,所以要加上m的大小
int n,m;
struct Splay{
struct NODE{
int s[2],p,v,id;
int size;
void init(int _v,int _id,int _p)
{
v=_v,id=_id,p=_p;
size=1;
}
}tr[N];
int root[N],idx;
void pushup(int x)
{
tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+1;
}
void rotate(int x)
{
int y=tr[x].p,z=tr[y].p;
int k=tr[y].s[1]==x;
tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
pushup(y),pushup(x);
}
void splay(int x,int k,int b)
{
while(tr[x].p!=k)
{
int y=tr[x].p,z=tr[y].p;
if(z!=k)
if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
else rotate(y);
rotate(x);
}
if(!k) root[b]=x;
}
void insert(int v,int id,int b)
{
int u=root[b],p=0;
while(u) p=u,u=tr[u].s[v>tr[u].v];
u=++idx;
if(p) tr[p].s[v>tr[p].v]=u;
tr[u].init(v,id,p);
splay(u,0,b);
}
int get_k(int k,int b)
{
int u=root[b];
while(u)
{
if(tr[tr[u].s[0]].size>=k) u=tr[u].s[0];
else if(tr[tr[u].s[0]].size+1==k)
return splay(u,0,b),tr[u].id;
else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
}
return -1;
}
void merge(int u,int b)
{
if(tr[u].s[0]) merge(tr[u].s[0],b);
if(tr[u].s[1]) merge(tr[u].s[1],b);
insert(tr[u].v,tr[u].id,b);
}
}Tr;
int f[N];
int find(int x)
{
if(f[x]!=x) f[x]=find(f[x]);
return f[x];
}
signed main()
{
cin>>n>>m;
for(int i=1;i<=n;i++)
{
f[i]=Tr.root[i]=i;
int v;
cin>>v;
Tr.tr[i].init(v,i,0);
}
Tr.idx=n;
while(m--)
{
int a,b;
cin>>a>>b;
a=find(a),b=find(b);
if(a!=b)
{
if(Tr.tr[Tr.root[a]].size>Tr.tr[Tr.root[b]].size) swap(a,b);
Tr.merge(Tr.root[a],b);
f[a]=b;
}
}
cin>>m;
while(m--)
{
char op;
int a,b;
cin>>op>>a>>b;
if(op=='B')
{
a=find(a),b=find(b);
if(a!=b)
{
if(Tr.tr[Tr.root[a]].size>Tr.tr[Tr.root[b]].size) swap(a,b);
Tr.merge(Tr.root[a],b);
f[a]=b;
}
}
else
{
a=find(a);
if(Tr.tr[Tr.root[a]].size<b) cout<<-1<<endl;
else cout<<Tr.get_k(b,a)<<endl;
}
}
return 0;
}
③:维护数列
内存回收:
把删掉的点再利用:建一个数组(回收站),把被删的点记录下来(其实是记录可用下标),要用的时候从数组里面找一个即可。提高空间利用率。
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5e5+100,INF=1e18;
int n,m;
int w[N];
struct Splay{
struct NODE{
int s[2],p,v;
int rev,same;
int size,sum,ms,ls,rs;
void init(int _v,int _p)
{
s[0]=s[1]=0,p=_p,v=_v;
rev=same=0;
size=1,sum=ms=v;
ls=rs=max(v,0ll);
}
}tr[N];
int root;
int nodes[N],tt;//回收站
void init_nodes()
{
for(int i=1;i<N;i++)
nodes[++tt]=i;
}
void pushup(int x)
{
auto &u=tr[x],&l=tr[u.s[0]],&r=tr[u.s[1]];
u.size=l.size+r.size+1;
u.sum=l.sum+r.sum+u.v;
u.ls=max(l.ls,l.sum+u.v+r.ls);
u.rs=max(r.rs,r.sum+u.v+l.rs);
u.ms=max({l.ms,r.ms,l.rs+u.v+r.ls});
}
void pushdown(int x)
{
auto &u=tr[x],&l=tr[u.s[0]],&r=tr[u.s[1]];
if(u.same)
{
u.same=u.rev=0;
if(u.s[0]) l.same=1,l.v=u.v,l.sum=l.v*l.size;
if(u.s[1]) r.same=1,r.v=u.v,r.sum=r.v*r.size;
if(u.v>0)
{
if(u.s[0]) l.ms=l.ls=l.rs=l.sum;
if(u.s[1]) r.ms=r.ls=r.rs=r.sum;
}
else
{
if(u.s[0]) l.ms=l.v,l.ls=l.rs=0;
if(u.s[1]) r.ms=r.v,r.ls=r.rs=0;
}
}
else if(u.rev)
{
u.rev=0,l.rev^=1,r.rev^=1;
swap(l.ls,l.rs),swap(r.ls,r.rs);
swap(l.s[0],l.s[1]),swap(r.s[0],r.s[1]);
}
}
void rotate(int x)
{
int y=tr[x].p,z=tr[y].p;
int k=tr[y].s[1]==x;
tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
pushup(y),pushup(x);
}
void splay(int x,int k)
{
while(tr[x].p!=k)
{
int y=tr[x].p,z=tr[y].p;
if(z!=k)
if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
else rotate(y);
rotate(x);
}
if(!k) root=x;
}
int get_k(int k)
{
int u=root;
while(u)
{
pushdown(u);
if(tr[tr[u].s[0]].size>=k) u=tr[u].s[0];
else if(tr[tr[u].s[0]].size+1==k) return u;
else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
}
return -1;
}
int build(int p,int l,int r)
{
int mid=l+r>>1;
int u=nodes[tt--];
tr[u].init(w[mid],p);
if(l<mid) tr[u].s[0]=build(u,l,mid-1);
if(mid<r) tr[u].s[1]=build(u,mid+1,r);
pushup(u);
return u;
}
void dfs(int u)
{
if(tr[u].s[0]) dfs(tr[u].s[0]);
if(tr[u].s[1]) dfs(tr[u].s[1]);
nodes[++tt]=u;
}
}Tr;
signed main()
{
Tr.init_nodes();
cin>>n>>m;
Tr.tr[0].ms=w[0]=w[n+1]=-INF;
for(int i=1;i<=n;i++) cin>>w[i];
Tr.root=Tr.build(0,0,n+1);
char op[30];
while(m--)
{
cin>>op;
if(!strcmp(op,"INSERT"))
{
int posi,tot;
cin>>posi>>tot;
for(int i=0;i<tot;i++) cin>>w[i];
int l=Tr.get_k(posi+1),r=Tr.get_k(posi+2);
Tr.splay(l,0),Tr.splay(r,l);
int u=Tr.build(r,0,tot-1);
Tr.tr[r].s[0]=u;
Tr.pushup(r),Tr.pushup(l);
}
else if(!strcmp(op,"DELETE"))
{
int posi,tot;
cin>>posi>>tot;
int l=Tr.get_k(posi),r=Tr.get_k(posi+tot+1);
Tr.splay(l,0),Tr.splay(r,l);
Tr.dfs(Tr.tr[r].s[0]);
Tr.tr[r].s[0]=0;
Tr.pushup(r),Tr.pushup(l);
}
else if(!strcmp(op,"MAKE-SAME"))
{
int posi,tot,c;
cin>>posi>>tot>>c;
int l=Tr.get_k(posi),r=Tr.get_k(posi+tot+1);
Tr.splay(l,0),Tr.splay(r,l);
auto &son=Tr.tr[Tr.tr[r].s[0]];
son.same=1,son.v=c,son.sum=c*son.size;
if(c>0) son.ms=son.ls=son.rs=son.sum;
else son.ms=c,son.ls=son.rs=0;
Tr.pushup(r),Tr.pushup(l);
}
else if(!strcmp(op,"REVERSE"))
{
int posi,tot;
cin>>posi>>tot;
int l=Tr.get_k(posi),r=Tr.get_k(posi+tot+1);
Tr.splay(l,0),Tr.splay(r,l);
auto &son=Tr.tr[Tr.tr[r].s[0]];
son.rev^=1;
swap(son.ls,son.rs);
swap(son.s[0],son.s[1]);
Tr.pushup(r),Tr.pushup(l);
}
else if(!strcmp(op,"GET-SUM"))
{
int posi,tot;
cin>>posi>>tot;
int l=Tr.get_k(posi),r=Tr.get_k(posi+tot+1);
Tr.splay(l,0),Tr.splay(r,l);
cout<<Tr.tr[Tr.tr[r].s[0]].sum<<endl;
}
else cout<<Tr.tr[Tr.root].ms<<endl;
}
return 0;
}