// 【模板】splay — splay tree template (ordered multiset)

#include<bits/stdc++.h>
using namespace std;
// Static node pool for the splay tree. Node 0 acts as the null sentinel:
// its fields stay zero, so sz[0]==cnt[0]==0 and child/parent links of 0 mean
// "none". Node ids are handed out from 1 upward and are never reused.
const int kMaxNodes=100010;
int ch[kMaxNodes][2]; // ch[x][0] / ch[x][1]: left / right child of node x
int fa[kMaxNodes];    // parent of node x (0 when x is a root)
int sz[kMaxNodes];    // total multiplicity stored in the subtree rooted at x
int cnt[kMaxNodes];   // multiplicity of the key held by node x
int val[kMaxNodes];   // key held by node x
int rt;               // root of the tree (0 when empty)
int id;               // last node id that was allocated
// Which side of its parent node x hangs on: true = right child, false = left.
// For a root (parent 0) this is false, since ch[0][1] is always 0.
bool dir(int x)
{
  const int parent=fa[x];
  return ch[parent][1]==x;
}
// Recompute sz[x] from its own multiplicity and its children's subtree sizes.
// Children must already hold correct sizes (sz[0] is 0 for missing children).
void push_up(int x)
{
  const int left=sz[ch[x][0]];
  const int right=sz[ch[x][1]];
  sz[x]=left+right+cnt[x];
}
// Rotate node x above its parent y, preserving in-order (BST) order.
// The child of x on the side facing y is handed over to y, y becomes x's
// child, and x takes y's place under the grandparent z (if any).
// Sizes must be recomputed bottom-up: y is now a child of x, so y is updated
// first and x afterwards. Doing x first would read y's stale pre-rotation size,
// which still counts x's own subtree.
void rotate(int x)
{
  int y=fa[x],z=fa[y];
  bool r=dir(x);
  ch[y][r]=ch[x][!r];
  ch[x][!r]=y;
  // dir(y) still sees fa[y]==z here, so it locates y's slot under z.
  if(z)ch[z][dir(y)]=x;
  if(ch[y][r])fa[ch[y][r]]=y;
  fa[y]=x;
  fa[x]=z;
  push_up(y);
  push_up(x);
}
// Splay node x (which must lie in the subtree rooted at z) up to the position
// of z, then store x into z. The stop condition is "x's parent equals z's
// original parent", so this also works on a subtree, not just the whole tree.
// Zig-zig (x and its parent on the same side) rotates the parent first;
// zig-zag rotates x twice.
void splay(int& z,int x)
{
  const int stop=fa[z];
  while(fa[x]!=stop)
  {
    int p=fa[x];
    if(fa[p]!=stop)
    {
      if(dir(x)==dir(p))rotate(p);
      else rotate(x);
    }
    rotate(x);
  }
  z=x;
}
// BST search for key v inside the subtree rooted at z, then splay to z's
// position: the node holding v if present, otherwise the last node visited.
// That last node holds either v's predecessor or successor.
// On an empty subtree (z==0) this leaves z at 0.
void find(int &z,int v)
{
  int cur=z,last=fa[cur];
  while(cur&&val[cur]!=v)
  {
    last=cur;
    cur=ch[cur][v>val[cur]];
  }
  splay(z,cur?cur:last);
}
// Locate the node holding the k-th smallest element (1-based, counting
// multiplicities) in the subtree rooted at z and splay it to z's position.
// Precondition: 1 <= k <= sz[z]. Otherwise the descent runs into node 0 and
// never terminates.
void loc(int& z,int k)
{
  int cur=z;
  for(;;)
  {
    int left=sz[ch[cur][0]];
    if(k<=left){cur=ch[cur][0];continue;}
    k-=left;                 // skip the whole left subtree
    if(k<=cnt[cur])break;    // answer is one of this node's copies
    k-=cnt[cur];             // skip this node's copies too
    cur=ch[cur][1];
  }
  splay(z,cur);
}
// Join two trees rooted at x and y, where every key under x is smaller than
// every key under y, and return the new root. The minimum of y is splayed to
// y's root (so it has no left child) and x is attached as its left subtree.
// Expects fa[x]==fa[y]==0 beforehand; remove() clears them before calling.
int merge(int x,int y)
{
  if(x==0)return y;
  if(y==0)return x;
  loc(y,1);
  fa[x]=y;
  ch[y][0]=x;
  push_up(y);
  return y;
}
void insert(int v)
{
  int x=rt,y=0;
  for(;x&&val[x]!=v;x=ch[y=x][v>val[x]]);
  if(x){++cnt[x];++sz[x];}
  else
  {
    x=++id;
    val[x]=v;
    cnt[x]=sz[x]=1;
    fa[x]=y;
    if(y)ch[y][v>val[y]]=x;
  }
  splay(rt,x);
}
void remove(int v)
{
  find(rt,v);
  if(!rt||val[rt]!=v)return;
  --cnt[rt];
  --sz[rt];
  if(!cnt[rt])
  {
    int x=ch[rt][0],y=ch[rt][1];
    fa[x]=fa[y]=0;
    rt=merge(x,y);
  }
}
// Return 1 + the number of stored elements (with multiplicity) strictly less
// than v; v itself need not be present. After find(), every element under the
// root's left child is < v, and the root's copies count too if its key is < v.
int find_rank(int v)
{
  find(rt,v);
  int smaller=sz[ch[rt][0]];
  if(val[rt]<v)smaller+=cnt[rt];
  return smaller+1;
}
// Return the v-th smallest stored element (1-based, counting multiplicities)
// and splay it to the root. Returns -1 when v is outside [1, sz[rt]].
// Note: -1 is ambiguous if -1 itself can be stored.
// Non-positive v must be rejected here: loc() would otherwise keep descending
// into node 0 forever.
int find_kth(int v)
{
  if(v<1||v>sz[rt])return -1;
  loc(rt,v);
  return val[rt];
}
// Return the largest stored key strictly less than v, or -1 if none exists
// (ambiguous if -1 itself is stored). The answer ends up splayed to the root
// unless the tree is empty.
int find_prev(int v)
{
  find(rt,v);
  // The splayed root is the predecessor already when its key is below v.
  if(rt!=0&&val[rt]<v)return val[rt];
  // Otherwise the predecessor is the rightmost node of the left subtree.
  int best=ch[rt][0];
  if(best==0)return -1;
  while(ch[best][1]!=0)best=ch[best][1];
  splay(rt,best);
  return val[best];
}
// Return the smallest stored key strictly greater than v, or -1 if none exists
// (ambiguous if -1 itself is stored). The answer ends up splayed to the root
// unless the tree is empty.
int find_nxt(int v)
{
  find(rt,v);
  // The splayed root is the successor already when its key is above v.
  if(rt!=0&&val[rt]>v)return val[rt];
  // Otherwise the successor is the leftmost node of the right subtree.
  int best=ch[rt][1];
  if(best==0)return -1;
  while(ch[best][0]!=0)best=ch[best][0];
  splay(rt,best);
  return val[best];
}
// Read n operations "op x" from stdin and apply them to the global tree:
//   1: insert x          2: remove one x       3: print rank of x
//   4: print x-th value  5: print predecessor  6: print successor
// Query results are printed one per line; unknown opcodes are ignored.
// scanf's return value is checked so truncated or malformed input stops
// processing, instead of reading n/op/x uninitialized.
int main()
{
  int n,op,x;
  if(scanf("%d",&n)!=1)return 0;
  while(n--)
  {
    if(scanf("%d%d",&op,&x)!=2)break;
    switch(op)
    {
      case 1:
        insert(x);
        break;
      case 2:
        remove(x);
        break;
      case 3:
        printf("%d\n",find_rank(x));
        break;
      case 4:
        printf("%d\n",find_kth(x));
        break;
      case 5:
        printf("%d\n",find_prev(x));
        break;
      case 6:
        printf("%d\n",find_nxt(x));
        break;
    }
  }
  return 0;
}
// Posted 2025-05-28 17:46 by Astral_Plane.