// BZOJ 1861: [ZJOI2006] Book (书架) | standard Splay-tree exercise (Splay 板题)

#include<cstdio>
#include<algorithm>
#include<cstring>
// Array capacity: node ids (book numbers) and indices into a[] must stay below N.
#define N 80010
// True iff x is the left child of its parent (for the root this reads ls[0]).
#define which(x) (ls[fa[(x)]]==(x))
using namespace std;
// Splay tree stored in parallel arrays; node ids are the book numbers read into a[].
//   ls / rs / fa : left child, right child, parent (0 means "none")
//   sz           : subtree size; sz[0] is never written, so it stays 0
//   root         : current root node (0 means the tree is empty)
//   n, m         : number of books and number of commands
//   a[1..n]      : initial order; Build() makes a[1] the leftmost node
//   id, pos, tot : not referenced by the code shown here
// The tree's left-to-right (in-order) sequence is the shelf order: getrank()
// counts the nodes left of a book, and the 'T' command re-inserts it leftmost.
int id[N],pos[N],sz[N],ls[N],rs[N],fa[N],tot,n,m,root,a[N];
// Buffer for the command word read with scanf("%s").
char op[233];
// Recompute the subtree size of x from its two children plus x itself
// (relies on sz[0] being 0 so that missing children contribute nothing).
void pushup(int x)
{
    sz[x] = sz[ls[x]] + sz[rs[x]] + 1;
}
// Recursively build a balanced subtree whose in-order sequence is a[l..r],
// hanging its root under `pre`. Returns the root's node id, or 0 when the
// range is empty.
int Build(int l,int r,int pre)
{
    if (r<l)
        return 0;
    const int mid=(l+r)>>1;
    const int node=a[mid];
    ls[node]=Build(l,mid-1,node);
    rs[node]=Build(mid+1,r,node);
    fa[node]=pre;
    pushup(node);
    return node;
}
// Single rotation: lift u one level above its parent v while preserving the
// in-order sequence. u's child on the side facing v (b) is re-hung under v.
// Recomputes sizes of v, then u (v is now u's child, so it goes first).
// Does not touch `root`; Splay() updates it.
void Rotate(int u){
    // v: parent, w: grandparent (0 when v is the root),
    // b: u's inner child, which moves across to become v's child.
    int v = fa[u], w = fa[v], b = which(u) ? rs[u] : ls[u];
    // Put u into v's old slot under w; skipped when v was the root.
    if(w) which(v) ? ls[w] = u : rs[w] = u;
    // fa[] is only updated on the next line, so which(u) still refers to v here.
    // u was a left child: b -> v's left, v -> u's right; otherwise the mirror.
    which(u) ? (ls[v] = b, rs[u] = v) : (rs[v] = b, ls[u] = v);
    fa[u] = w, fa[v] = u;
    if(b) fa[b] = v;
    pushup(v), pushup(u);
}
// Splay u upward until its parent is `tar` (tar == 0 makes u the root).
// When u and its parent are on the same side of their parents (zig-zig),
// the parent is rotated first; otherwise (zig-zag) u is rotated twice.
// Updates `root` when splaying all the way to the top.
void Splay(int u, int tar){
    for (int p; (p = fa[u]) != tar; Rotate(u)) {
        if (fa[p] == tar)
            continue;   // single zig step: only the trailing Rotate(u) runs
        Rotate(which(u) == which(p) ? p : u);
    }
    if (tar == 0)
        root = u;
}
// Number of books above book u (its 0-based in-order position).
// Side effect: u is splayed to the root.
int getrank(int u)
{
    Splay(u,0);
    const int above=sz[ls[u]];
    return above;
}
int getkth(int k)
{
    int cur=root;
    while (sz[ls[cur]]+1!=k)
    {
    if (k<=sz[ls[cur]]) cur=ls[cur];
    else k-=sz[ls[cur]]+1,cur=rs[cur];
    }
    return cur;
}
// Insert the detached node v at the leftmost (top != 0) or rightmost
// (top == 0) end of the in-order sequence, then splay it to the root.
// An empty tree (root == 0, e.g. n == 1 right after erase) is handled
// explicitly. Walking from node 0 would otherwise write ls[0]/rs[0] and
// corrupt the sentinel, and a later insert could then create a self-loop.
void insert(int v,int top)
{
    ls[v]=rs[v]=0, sz[v]=1;
    if (!root)
    {
        fa[v]=0;
        root=v;
        return;
    }
    int u=root;
    if (top)
    {
        while (ls[u]) u=ls[u];
        ls[u]=v;
    }
    else
    {
        while (rs[u]) u=rs[u];
        rs[u]=v;
    }
    fa[v]=u;
    Splay(v,0);   // rotations recompute the sizes of all former ancestors
}
void erase(int u)
{
    Splay(u,0);
    if (sz[u]==0) root;
    else if (!ls[u] || !rs[u]) root=ls[u]+rs[u],fa[root]=0;
    else
    {
    fa[ls[u]]=0;
    int v=ls[u];
    while (rs[v]) v=rs[v];
    Splay(v,0);
    rs[v]=rs[u],fa[rs[u]]=v,pushup(v);
    }
}
// Swap book u with its neighbour on the shelf: the book directly above it
// when x < 0 (in-order predecessor) or directly below it when x > 0
// (successor). x == 0, or no neighbour in that direction, leaves the shelf
// unchanged. On a swap, the neighbour v ends up as the root with u as its child.
void swp(int u,int x)
{
    if (!x) return;
    Splay(u,0);
    const bool up=x<0;
    int v=up?ls[u]:rs[u];
    if (!v) return;                              // u is already at that end
    while (up?rs[v]:ls[v]) v=up?rs[v]:ls[v];     // nearest node in that direction
    Splay(v,u);   // v becomes u's child; its child facing u is empty
    if (up) swap(rs[u],rs[v]),ls[u]=ls[v],ls[v]=u;
    else swap(ls[u],ls[v]),rs[u]=rs[v],rs[v]=u;
    fa[v]=0,fa[ls[u]]=fa[rs[u]]=u,fa[ls[v]]=fa[rs[v]]=v;
    pushup(u),pushup(v);   // u is now below v, so update it first
    root=v;
}
// Read n and the initial shelf order, then process m commands:
//   T s / B s : move book s to the top / bottom
//   A s       : print the number of books above s
//   Q k       : print the k-th book from the top
//   I s t     : swap s with its upper (t = -1) or lower (t = 1) neighbour
// Stops early on malformed or truncated input instead of reusing stale values.
int main()
{
    if (scanf("%d%d",&n,&m)!=2) return 0;
    for (int i=1;i<=n;i++)
        if (scanf("%d",&a[i])!=1) return 0;
    root=Build(1,n,0);
    for (int i=1;i<=m;i++)
    {
        int x,t;
        if (scanf("%232s%d",op,&x)!=2) break;   // width matches op[233]
        switch (op[0])
        {
        case 'T': erase(x),insert(x,1); break;
        case 'B': erase(x),insert(x,0); break;
        case 'A': printf("%d\n",getrank(x)); break;
        case 'Q': printf("%d\n",getkth(x)); break;
        case 'I':
            if (scanf("%d",&t)==1 && t) swp(x,t);
            break;
        }
    }
    return 0;
}

 

// posted @ 2017-12-27 19:54  MSPqwq  阅读(143)  评论(0)  编辑  收藏  举报