Splay 模板略记

有趣的单词:OvO zig-zag OvO

基础定义和函数:

#define lc (tr[o].ch[0])
#define rc (tr[o].ch[1])
#define get(u) (tr[tr[u].fa].ch[1]==u)
int rt,tot;
struct node{
    ll val;
    int cnt,sz;
    int fa,ch[2];
    void set(ll v,int c=1,int s=1){ ch[0]=ch[1]=fa=0,val=v,cnt=c,sz=s; }
}tr[maxn];
void addson(int fa,int u,bool op){ tr[tr[u].fa=fa].ch[op]=u; }
void delson(int o){ tr[lc].fa=tr[rc].fa=tr[tr[o].fa].ch[get(o)]=0,tr[o].set(0,0,0); }
void pushup(int o){ tr[o].sz=tr[lc].sz+tr[rc].sz+tr[o].cnt; }

把节点 \(u\) 往上旋:

void rotate(int u){
    int op=get(u),o=tr[u].fa,ffa=tr[o].fa,son=tr[u].ch[op^1];
    addson(ffa,u,get(o)),addson(u,o,op^1),addson(o,son,op);
    pushup(o),pushup(u);
}

Splay 到 \(v\) 为止:

int splay(int u,int v=0){
    for(int o;(o=tr[u].fa)^v;rotate(u)) 
        if(tr[o].fa^v) rotate(get(o)==get(u)?o:u);
    if(!v) rt=u;return u;
}

如果有父亲且方向和父亲一致,要先旋父亲。如果一直转 \(u\) 复杂度就假了。

插入删除:

int insert(ll key)
{
    if(!rt){ tr[++tot].set(key);return rt=tot; }
	int o=getid(key);
    if(key==tr[o].val){ tr[o].cnt++;pushup(o);return o; }
    tr[++tot].set(key);
	addson(o,tot,key>tr[o].val);
	return splay(tot);
}
void remove(ll key)
{
	int o=getid(key);
    if(tr[o].cnt>1) tr[o].cnt--;
    else{
        int pre=getpre(key),nxt=getnxt(key);
        splay(pre),splay(nxt,pre),delson(o),splay(nxt);
    }
    pushup(rt);
}

求前驱后继:

int bound(ll key,bool op){
	int u=tr[insert(key)].ch[op];
	while(tr[u].ch[op^1]) u=tr[u].ch[op^1];
	remove(key);return u;
}
int getpre(ll key){ return bound(key,0); }
int getnxt(ll key){ return bound(key,1); }

可以先把 key 插入,此时 key 的节点已经 Splay 到根。如果要求前驱,就在 key 的左子树中不断跳右儿子。

模板题完整代码
#include<bits/stdc++.h>
#define For(i,il,ir) for(int i=(il);i<=(ir);i++)
#define Rof(i,ir,il) for(int i=(ir);i>=(il);i--)
using namespace std;
typedef long long ll;
const ll inf=1e10;
const int maxn=2e6+10;

int n,m;
ll a[maxn];

#define lc tr[o].ch[0]
#define rc tr[o].ch[1]
#define get(u) (u==tr[tr[u].fa].ch[1])

int rt,idtot;
struct Splay{
    ll val;
    int cnt,sz;
    int fa,ch[2];
    void set(ll v=0,ll c=0,ll s=0){ val=v,cnt=c,sz=s; }
}tr[maxn];
void addson(int f,int u,bool op){ tr[f].ch[op]=u,tr[u].fa=f; }
void delson(int o){ tr[tr[o].fa].ch[get(o)]=0,tr[o].fa=0,tr[o].set(); }
void pushup(int o){ tr[o].sz=tr[lc].sz+tr[rc].sz+tr[o].cnt; }

void rotate(int u){
    int o=tr[u].fa,ffa=tr[o].fa,op=get(u),son=tr[u].ch[op^1];
    addson(ffa,u,get(o)),addson(u,o,op^1),addson(o,son,op);
    pushup(o),pushup(u);
}
int splay(int u,int v=0){
    for(int o;(o=tr[u].fa)^v;rotate(u))
        if(tr[o].fa^v) rotate(get(u)==get(o)?o:u);
    if(!v) rt=u; return u;
}

int getid(ll key){
    int u=rt,lst;
    while(u)
        if(key==tr[lst=u].val) return splay(u);
        else u=tr[u].ch[key>tr[u].val];
	return lst;
}
int insert(ll key){
    if(!rt){ tr[rt=++idtot].set(key,1,1); return rt; }
    int o=getid(key);
    if(key==tr[o].val){
        tr[o].cnt++,tr[o].sz++;
        return o;
    }
    tr[++idtot].set(key,1,1);
    addson(o,idtot,key>tr[o].val);
    return splay(idtot);
}
void remove(ll key);
int bound(ll key,bool op){
	int u=tr[insert(key)].ch[op];
	while(tr[u].ch[op^1]) u=tr[u].ch[op^1];
	remove(key); return u;
}
int getpre(ll key){ return bound(key,0); }
int getnxt(ll key){ return bound(key,1); }
void remove(ll key)
{
	int o=getid(key);
    if(tr[o].cnt>1) tr[o].cnt--;
    else{
        int pre=getpre(key),nxt=getnxt(key);
        splay(pre),splay(nxt,pre),delson(o),splay(nxt);
    }
    pushup(rt);
}
int getrk(ll key){
    int ans=tr[tr[insert(key)].ch[0]].sz+1;
	remove(key); return ans;
}
ll getkey(int o,int k){
    int tmp=tr[lc].sz;
    if(k<=tmp) return getkey(lc,k);
    else if(k>tmp+tr[o].cnt) return getkey(rc,k-tr[o].cnt-tmp);
	return splay(o);
}

signed main()
{
    scanf("%d%d",&n,&m);
    insert(inf),insert(-inf);
    For(i,1,n) scanf("%lld",&a[i]),insert(a[i]);
    
    ll lst=0,res=0;
    while(m--)
    {
        int op;ll x;scanf("%d%lld",&op,&x);x^=lst;
        if(op==1) insert(x);
        else if(op==2) remove(x);
        else if(op==3) lst=(getrk(x)-1);
        else if(op==4) lst=tr[getkey(rt,x+1)].val;
        else if(op==5) lst=tr[splay(getpre(x))].val;
        else if(op==6) lst=tr[splay(getnxt(x))].val;
        if(op>2) res^=lst;
    }
    printf("%lld\n",res);
    return 0;
}

Splay 分裂出一段区间:

int kth(int o,int k){
    pushdown(o);
    if(k<=tr[lc].sz) return kth(lc,k);
    else if(k<=tr[lc].sz+1) return splay(o);
    else return kth(rc,k-tr[lc].sz-1);
}
int split(int l,int r){
    int x=kth(rt,l-1),y=kth(rt,r+1);
    splay(x),splay(y,x);
    return tr[y].ch[0];
}

先把 \(l-1\) 旋到根,再把 \(r+1\) 旋到根的右儿子。

区间翻转:打 tag,在查找的时候 pushdown,修改后 pushup。

posted @ 2025-04-01 15:47  wanggk  阅读(47)  评论(0)    收藏  举报