Splay 模板略记
有趣的单词:OvO zig-zag OvO
基础定义和函数:
#define lc (tr[o].ch[0])
#define rc (tr[o].ch[1])
#define get(u) (tr[tr[u].fa].ch[1]==u)
int rt,tot;
struct node{
ll val;
int cnt,sz;
int fa,ch[2];
void set(ll v,int c=1,int s=1){ ch[0]=ch[1]=fa=0,val=v,cnt=c,sz=s; }
}tr[maxn];
void addson(int fa,int u,bool op){ tr[tr[u].fa=fa].ch[op]=u; }
void delson(int o){ tr[lc].fa=tr[rc].fa=tr[tr[o].fa].ch[get(o)]=0,tr[o].set(0,0,0); }
void pushup(int o){ tr[o].sz=tr[lc].sz+tr[rc].sz+tr[o].cnt; }
把节点 \(u\) 往上旋:
void rotate(int u){
int op=get(u),o=tr[u].fa,ffa=tr[o].fa,son=tr[u].ch[op^1];
addson(ffa,u,get(o)),addson(u,o,op^1),addson(o,son,op);
pushup(o),pushup(u);
}
Splay 到 \(v\) 为止:
int splay(int u,int v=0){
for(int o;(o=tr[u].fa)^v;rotate(u))
if(tr[o].fa^v) rotate(get(o)==get(u)?o:u);
if(!v) rt=u;return u;
}
如果有父亲且方向和父亲一致,要先旋父亲。如果一直转 \(u\) 复杂度就假了。
插入删除:
int insert(ll key)
{
if(!rt){ tr[++tot].set(key);return rt=tot; }
int o=getid(key);
if(key==tr[o].val){ tr[o].cnt++;pushup(o);return o; }
tr[++tot].set(key);
addson(o,tot,key>tr[o].val);
return splay(tot);
}
void remove(ll key)
{
int o=getid(key);
if(tr[o].cnt>1) tr[o].cnt--;
else{
int pre=getpre(key),nxt=getnxt(key);
splay(pre),splay(nxt,pre),delson(o),splay(nxt);
}
pushup(rt);
}
求前驱后继:
int bound(ll key,bool op){
int u=tr[insert(key)].ch[op];
while(tr[u].ch[op^1]) u=tr[u].ch[op^1];
remove(key);return u;
}
int getpre(ll key){ return bound(key,0); }
int getnxt(ll key){ return bound(key,1); }
可以先把 key 插入,此时 key 的节点已经 Splay 到根。如果要求前驱,就在 key 的左子树中不断跳右儿子。
模板题完整代码
#include<bits/stdc++.h>
#define For(i,il,ir) for(int i=(il);i<=(ir);i++)
#define Rof(i,ir,il) for(int i=(ir);i>=(il);i--)
using namespace std;
typedef long long ll;
const ll inf=1e10;
const int maxn=2e6+10;
int n,m;
ll a[maxn];
#define lc tr[o].ch[0]
#define rc tr[o].ch[1]
#define get(u) (u==tr[tr[u].fa].ch[1])
int rt,idtot;
struct Splay{
ll val;
int cnt,sz;
int fa,ch[2];
void set(ll v=0,ll c=0,ll s=0){ val=v,cnt=c,sz=s; }
}tr[maxn];
void addson(int f,int u,bool op){ tr[f].ch[op]=u,tr[u].fa=f; }
void delson(int o){ tr[tr[o].fa].ch[get(o)]=0,tr[o].fa=0,tr[o].set(); }
void pushup(int o){ tr[o].sz=tr[lc].sz+tr[rc].sz+tr[o].cnt; }
void rotate(int u){
int o=tr[u].fa,ffa=tr[o].fa,op=get(u),son=tr[u].ch[op^1];
addson(ffa,u,get(o)),addson(u,o,op^1),addson(o,son,op);
pushup(o),pushup(u);
}
int splay(int u,int v=0){
for(int o;(o=tr[u].fa)^v;rotate(u))
if(tr[o].fa^v) rotate(get(u)==get(o)?o:u);
if(!v) rt=u; return u;
}
int getid(ll key){
int u=rt,lst;
while(u)
if(key==tr[lst=u].val) return splay(u);
else u=tr[u].ch[key>tr[u].val];
return lst;
}
int insert(ll key){
if(!rt){ tr[rt=++idtot].set(key,1,1); return rt; }
int o=getid(key);
if(key==tr[o].val){
tr[o].cnt++,tr[o].sz++;
return o;
}
tr[++idtot].set(key,1,1);
addson(o,idtot,key>tr[o].val);
return splay(idtot);
}
void remove(ll key);
int bound(ll key,bool op){
int u=tr[insert(key)].ch[op];
while(tr[u].ch[op^1]) u=tr[u].ch[op^1];
remove(key); return u;
}
int getpre(ll key){ return bound(key,0); }
int getnxt(ll key){ return bound(key,1); }
void remove(ll key)
{
int o=getid(key);
if(tr[o].cnt>1) tr[o].cnt--;
else{
int pre=getpre(key),nxt=getnxt(key);
splay(pre),splay(nxt,pre),delson(o),splay(nxt);
}
pushup(rt);
}
int getrk(ll key){
int ans=tr[tr[insert(key)].ch[0]].sz+1;
remove(key); return ans;
}
ll getkey(int o,int k){
int tmp=tr[lc].sz;
if(k<=tmp) return getkey(lc,k);
else if(k>tmp+tr[o].cnt) return getkey(rc,k-tr[o].cnt-tmp);
return splay(o);
}
signed main()
{
scanf("%d%d",&n,&m);
insert(inf),insert(-inf);
For(i,1,n) scanf("%lld",&a[i]),insert(a[i]);
ll lst=0,res=0;
while(m--)
{
int op;ll x;scanf("%d%lld",&op,&x);x^=lst;
if(op==1) insert(x);
else if(op==2) remove(x);
else if(op==3) lst=(getrk(x)-1);
else if(op==4) lst=tr[getkey(rt,x+1)].val;
else if(op==5) lst=tr[splay(getpre(x))].val;
else if(op==6) lst=tr[splay(getnxt(x))].val;
if(op>2) res^=lst;
}
printf("%lld\n",res);
return 0;
}
Splay 分裂出一段区间:
int kth(int o,int k){
pushdown(o);
if(k<=tr[lc].sz) return kth(lc,k);
else if(k<=tr[lc].sz+1) return splay(o);
else return kth(rc,k-tr[lc].sz-1);
}
int split(int l,int r){
int x=kth(rt,l-1),y=kth(rt,r+1);
splay(x),splay(y,x);
return tr[y].ch[0];
}
先把 \(l-1\) 旋到根,再把 \(r+1\) 旋到根的右儿子。
区间翻转:打 tag,在查找的时候 pushdown,修改后 pushup。

浙公网安备 33010602011771号