[学习笔记]树套树 线段树套Splay

今天调了一个早上哈哈哈,不过因为\(Splay\),常数比较大

洛谷的评测记录:

\(Code\ Below:\)

#include <bits/stdc++.h>
#define ll long long
#define mid (l+r>>1)
using namespace std;
const int maxn=4000010;
const int inf=2147483647;
int n,m,sz,ans,Max;
int a[maxn],ch[maxn][2],fa[maxn],siz[maxn],cnt[maxn],key[maxn],rt[maxn];

inline void splayclear(int x){
    fa[x]=ch[x][0]=ch[x][1]=siz[x]=cnt[x]=key[x]=0;
}
inline bool get(int x){
    return ch[fa[x]][1]==x;
}
inline void update(int x){
    siz[x]=(ch[x][0]?siz[ch[x][0]]:0)+(ch[x][1]?siz[ch[x][1]]:0)+cnt[x];
}
inline void rotate(int x){
    int y=fa[x],z=fa[y],k=get(x);
    ch[y][k]=ch[x][k^1];fa[ch[y][k]]=y;
    ch[x][k^1]=y;fa[y]=x;fa[x]=z;
    if(z) ch[z][ch[z][1]==y]=x;
    update(y);update(x);
}
inline void splay(int i,int x,int top){
    for(int y,z;fa[x]!=top;rotate(x)){
        y=fa[x],z=fa[y];
        if(z!=top) rotate(get(x)==get(y)?y:x);
    }
    if(!top) rt[i]=x;
}
inline void splayinsert(int i,int val){
    int x=rt[i],y;
    if(!rt[i]){
        rt[i]=x=++sz;fa[x]=ch[x][0]=ch[x][1]=0;
        key[x]=val;siz[x]=cnt[x]=1;return;
    }
    while(1){
        if(key[x]==val){cnt[x]++;update(y);break;}
        y=x;x=ch[x][key[x]<val];
        if(x==0){
            x=++sz;key[x]=val;siz[x]=cnt[x]=1;
            ch[y][key[y]<val]=x;fa[x]=y;
            update(y);break;
        }
    }
    splay(i,x,0);
}
inline int splayrank(int i,int val){
    int x=rt[i],ret=0;
    while(x>0){
        if(val<key[x]) x=ch[x][0];
        else {
            ret+=(ch[x][0]?siz[ch[x][0]]:0);
            if(val==key[x]) return ret;
            ret+=cnt[x];x=ch[x][1];
        }
    }
    return ret;
}
inline int splayfind(int i,int val){
    int x=rt[i];
    while(x>0){
        if(val==key[x]){splay(i,x,0);return x;}
        x=ch[x][key[x]<val];
    }
}
inline int splaypre(int i){int x=ch[rt[i]][0];while(ch[x][1])x=ch[x][1];return x;}
inline int splaysuc(int i){int x=ch[rt[i]][1];while(ch[x][0])x=ch[x][0];return x;}
inline void splaydel(int i,int val){
    int x=splayfind(i,val),y=rt[i];
    if(cnt[x]>1){cnt[x]--;update(x);return;}
    if(!ch[x][0]&&!ch[x][1]){splayclear(rt[i]);rt[i]=0;return;}
    if(!ch[x][0]){y=ch[x][1];rt[i]=y;fa[y]=0;return;}
    if(!ch[x][1]){y=ch[x][0];rt[i]=y;fa[y]=0;return;}
    splay(i,splaypre(i),0);
    ch[rt[i]][1]=ch[y][1];fa[ch[y][1]]=rt[i];
    splayclear(y);update(rt[i]);
}
inline int splaygetpre(int i,int val){
    int x=rt[i];
    while(x>0){
        if(val<=key[x]) x=ch[x][0];
        else {
            if(key[x]>ans) ans=key[x];
            x=ch[x][1];
        }
    }
    return ans;
}
inline int splaygetsuc(int i,int val){
    int x=rt[i];
    while(x>0){
        if(val>=key[x]) x=ch[x][1];
        else {
            if(key[x]<ans) ans=key[x];
            x=ch[x][0];
        }
    }
    return ans;
}

inline void seginsert(int l,int r,int rt,int x,int val){
    splayinsert(rt,val);
    if(l == r) return;
    if(x <= mid) seginsert(l,mid,rt<<1,x,val);
    else seginsert(mid+1,r,rt<<1|1,x,val);
}
inline void segrank(int L,int R,int val,int l,int r,int rt){
    if(L <= l && r <= R){ans+=splayrank(rt,val);return;}
    if(L <= mid) segrank(L,R,val,l,mid,rt<<1);
    if(R > mid) segrank(L,R,val,mid+1,r,rt<<1|1);
}
inline void segchange(int l,int r,int rt,int x,int val){
    splaydel(rt,a[x]);splayinsert(rt,val);
    if(l == r){a[x]=val;return;}
    if(x <= mid) segchange(l,mid,rt<<1,x,val);
    else segchange(mid+1,r,rt<<1|1,x,val);
}
inline void segpre(int L,int R,int val,int l,int r,int rt){
    if(L <= l && r <= R){ans=max(ans,splaygetpre(rt,val));return;}
    if(L <= mid) segpre(L,R,val,l,mid,rt<<1);
    if(R > mid) segpre(L,R,val,mid+1,r,rt<<1|1);
}
inline void segsuc(int L,int R,int val,int l,int r,int rt){
    if(L <= l && r <= R){ans=min(ans,splaygetsuc(rt,val));return;}
    if(L <= mid) segsuc(L,R,val,l,mid,rt<<1);
    if(R > mid) segsuc(L,R,val,mid+1,r,rt<<1|1);
}
inline int getkth(int L,int R,int k){
    int l=0,r=Max+1,Mid;
    while(l<r){
        Mid=l+r>>1;ans=0;
        segrank(L,R,Mid,1,n,1);
        if(ans<k) l=Mid+1;
        else r=Mid;
    }
    return l-1;
}

int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
        seginsert(1,n,1,i,a[i]);
        Max=max(Max,a[i]);
    }
    int opt,l,r,val;
    while(m--){
        scanf("%d%d%d",&opt,&l,&r);
        switch(opt){
            case 1:scanf("%d",&val);ans=0;segrank(l,r,val,1,n,1);printf("%d\n",ans+1);break;
            case 2:scanf("%d",&val);printf("%d\n",getkth(l,r,val));break;
            case 3:segchange(1,n,1,l,r);break;
            case 4:scanf("%d",&val);ans=-inf;segpre(l,r,val,1,n,1);printf("%d\n",ans);break;
            case 5:scanf("%d",&val);ans=inf;segsuc(l,r,val,1,n,1);printf("%d\n",ans);break;
        }
    }
    return 0;	
}
posted @ 2018-09-29 16:38  Owen_codeisking  阅读(175)  评论(0编辑  收藏  举报