【二逼平衡树】树套树模板,线段树套平衡树

树套树入门?第一份代码达到7k,orz orz orz,我的splay总是会断owo。 LOJ106
这是一道模板题。 您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
  1. 查询 x 在区间内的排名;
  2. 查询区间内排名为 k 的值;
  3. 修改某一位置上的数值;
  4. 查询 x 在区间内的前趋(前趋定义为小于 x,且最大的数);
  5. 查询 x 在区间内的后继(后继定义为大于 x ,且最小的数)。

输入格式

第一行两个数 n,m ,表示长度为 n 的有序序列和 m个操作。 第二行有 n 个数,表示有序序列。 下面有 m 行,每行第一个数表示操作类型:
  1. 之后有三个数 l,r,x 表示查询 x 在区间 [l,r]的排名;
  2. 之后有三个数 l,r,k 表示查询区间 [l,r]内排名为 k 的数;
  3. 之后有两个数 pos,x 表示将 pos \mathrm{pos} pos 位置的数修改为 x
  4. 之后有三个数 l,r,x 表示查询区间 [l,r]x 的前趋;
  5. 之后有三个数 l,r,x 表示查询区间 [l,r] x 的后继。

输出格式

对于操作 1,2,4,5 各输出一行,表示查询结果。

样例

样例输入

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

样例输出

2
4
3
4
9

数据范围与提示

1≤n,m≤5×104,−108≤k,x≤108 题目很显然,就是一道毒瘤数据结构。 做法1(本蒟蒻的做法):树套树,线段树套平衡树(splay),对每个位置建线段树,线段树每个结点对应的l到r建立一颗平衡树。由于层数Logn,建树时间O(n log^2 n) 对于询问,我们把l到r区间内所有对应包含的线段树结点对应splay的root结点提取出来,然后 1,对每个splay找比他小的个数总和+1即可。 O(n log^2 n) 2.我们考虑二分,对mid找比他小的数总和然后调整即可,注意小心二分写法(蒟蒻太久没有手写过二分查找了)O(n log^3 n) 3,对包含这个位置所有的平衡树进行del后insert , O(nlog^2 n) 4.对所有区间内平衡树找前驱,取个max就好 5.同4,找所有后继,取个min就好 空间复杂度在于splay,由于线段树logn层,每层n个数,空间(nlogn) 时间 瓶颈在于操作2,时间(nlog^3n) 做法2:(考虑怎么优化操作2) 来自Oblack大佬%%% 同样是线段树套平衡树,只是将线段树维护的转化为数值(动态开点),然后对于每个线段树结点开平衡树维护这个结点的数值对应的位置。可以发现这样的话操作2就转化为了nlog^2n。 当然我们转过头来看原nlog^3n做法,我们会发现其实根本没法达到这么高的复杂度。因为对于一个区间得到的点也就是最多n个,我们均摊过来,每个平衡树内的结点并不会多,若结点多,那么对应区间需要扫的平衡树结点个数也就会相应减少。 数组不要开小了,血的教训,线段树一定空间要开两倍!!!
code:
//by newuser
#include<stdio.h>
#include<bits/stdc++.h>
#define zig(x) zigzag(x,1)
#define zag(x) zigzag(x,2)
using namespace std;
const int maxn = 50005;
const int spmaxn = 2500005;
int n,m,MAX;
int A[maxn];
int sta[maxn],statop; 
namespace spa
{
    int rt[maxn*2],cnt[spmaxn],ls[spmaxn],rs[spmaxn],dat[spmaxn],siz[spmaxn],fa[spmaxn],tot;
    inline void  putup(int x) { siz[x] = siz[ls[x]] + siz[rs[x]] + cnt[x]; }
	inline void zigzag(int x,int knd)
    {
    int y=fa[x],z=fa[y];
    if(z)
    {
    if(ls[z]==y) ls[z]=x;
    else rs[z]=x;
	}
    fa[x]=z; fa[y]=x;
    if(knd==1)
    {
        ls[y]=rs[x];
        fa[ls[y]]=y;
        rs[x]=y;
    }
    else
    {
        rs[y]=ls[x];
        fa[rs[y]]=y;
        ls[x]=y;
    }
    putup(y); putup(x);
    }
    inline void splay(int x,int root)
    {
    int y,z;
    while(fa[x])
    {
        y=fa[x]; z=fa[y];
        if(z)
        {
            if(ls[z]==y)
            {
                if(ls[y]==x) { zig(y); zig(x); }
                else { zag(x); zig(x); }
            }
            else
            {
                if(rs[y]==x) { zag(y); zag(x); }
                else { zig(x); zag(x); }
            }
        }
        else
        {
            if(ls[y]==x) zig(x);
            else zag(x);
        }
    }
    rt[root]=x;
    }
    inline void insert(int x,int root)
    {
        if(!rt[root]) { rt[root] = ++tot; cnt[tot]=siz[tot]=1; dat[tot]=x; ls[tot]=rs[tot]=fa[tot]=0; return ; }
        ++tot; int p = rt[root];
        while(p)
        {
            siz[p]++;
            if(x<dat[p])
            {
                if(!ls[p])
                {
                    ls[p]=tot;
                    break;
                }
                p=ls[p];
            }
            else if(x>dat[p])
            {
                if(!rs[p])
                {
                    rs[p]=tot;
                    break;
                }
                p=rs[p];
            }
            else { --tot; 	cnt[p]++; return; }
        }
        dat[tot]=x; cnt[tot]=siz[tot]=1;
        fa[tot] = p; ls[tot]=rs[tot]=0;
        splay(tot,root);
    }
    inline int getmax(int p)
    {
        while(rs[p]) p = rs[p];
        return p;
    }
    inline void del(int x,int root)
    {
    	splay(x,root);
        --cnt[x];
        siz[x]--;
        if(cnt[x]) return;
        int ll=ls[x]; int rr=rs[x];
        fa[ll]=fa[rr]=ls[x]=rs[x]=siz[x]=0;
        if(!ll) { rt[root]=rr; return; }
        if(!rr) { rt[root]=ll; return; }
        ll=getmax(ll);
        splay(ll,root);
        fa[ll]=0;
        fa[rr]=ll; rs[ll]=rr;
        putup(rr); putup(ll); 
        return;
    }
    inline int fi(int x,int root)
    {
        int p=rt[root];
        while(p)
        {
            if(x==dat[p]) return p;
            else if(x<dat[p]) p = ls[p];
            else if(x>dat[p]) p = rs[p];
        }
        splay(p,root);
        return p;
    }
    inline int getsmaller(int x,int root)
    {
        int p = rt[root],ans=0;
        while(p)
        {
            if(dat[p]>x) p = ls[p];
            else if(dat[p]<x)
            {
                ans+=siz[ls[p]]+cnt[p];
                p = rs[p];
            }
            else 
            {
                ans+=siz[ls[p]];
                break;
            }
        }
        if(p) splay(p,root);
        return ans;
    }
    inline int getqianqu(int x,int root)
    {
        int p = rt[root],ans=-0x3f3f3f3f;
        while(p)
        {
            if(dat[p]>=x)
            {
                p = ls[p];
            }
            else
            {
                if(dat[p]>ans) ans = dat[p];
                p = rs[p];
            } 
        }
        if(p) splay(p,root);
        return ans;
    }
    inline int gethouji(int x,int root)
    {
        int p = rt[root],ans=0x3f3f3f3f;
        while(p)
        {
            if(dat[p]<=x) 
            {
                p=rs[p];
            }
            else
            {
                if(dat[p]<ans) ans = dat[p];
                p = ls[p];
            }
        }
        if(p) splay(p,root);
        return ans;
    }
}
namespace seg
{
    int ls[maxn*2],rs[maxn*2],dy[maxn*2],tot;
    int maketree(int l,int r)
    {
        int p = ++tot;
        if(l<r)
        {
            int mid = (l+r)>>1;
            ls[p] = maketree(l,mid);
            rs[p] = maketree(mid+1,r);
        }
        for(int i=l;i<=r;i++) spa::insert(A[i],p);
        return p;
    }
    inline void getqujian(int p,int l,int r,int x,int y)
    {
        if(x<=l&&r<=y)
        {
            sta[++statop] = p;
            return;
        }
        int mid = (l+r)>>1;
        if(x>mid) getqujian(rs[p],mid+1,r,x,y);
        else if(y<=mid) getqujian(ls[p],l,mid,x,y);
        else getqujian(ls[p],l,mid,x,y),getqujian(rs[p],mid+1,r,x,y);
    }
    inline void getbaohan(int p,int l,int r,int k)
    {
    	sta[++statop] = p;
    	if(l<r)
    	{
    		int mid = (l+r)>>1;
    		if(k<=mid) getbaohan(ls[p],l,mid,k);
    		else getbaohan(rs[p],mid+1,r,k);
    	}
    	return;
    }
} 
inline int solve1()
{
    int l,r,k;
    scanf("%d%d%d",&l,&r,&k);
    statop=0;
    seg::getqujian(1,1,n,l,r);
    int ans = 1;
    for(int i=1;i<=statop;i++)
    {
        ans+=spa::getsmaller(k,sta[i]);
    }
    return ans;
}
inline int ggg(int mid)
{
    int tmp=0;
    for(int i=1;i<=statop;i++)
    {
        tmp+=spa::getsmaller(mid,sta[i]);
    }
    return tmp;
}
inline int solve2()
{
    int l,r,k;
    scanf("%d%d%d",&l,&r,&k);
    int L=0;int R=MAX; int mid;
    statop=0;
    seg::getqujian(1,1,n,l,r);
    while(L<=R)
    {
        mid = (L+R)>>1;
        int aha = ggg(mid);
        if(aha<k) L = mid+1;
        else R=mid-1;
    }
    return L-1;
}
inline void solve3()
{
    int pos,k;
    scanf("%d%d",&pos,&k);
    statop=0; 
    seg::getbaohan(1,1,n,pos);
    for(int i=1;i<=statop;i++)
    {    	int x = spa::fi(A[pos],sta[i]);
    	spa::del(x,sta[i]);
    	spa::insert(k,sta[i]);
    }
    A[pos] = k;
    MAX = max(MAX,k);
}
inline int solve4()
{
    int l,r,k;
    scanf("%d%d%d",&l,&r,&k);
    statop=0;
    seg::getqujian(1,1,n,l,r);
    int ans = -0x3f3f3f3f;
    for(int i=1;i<=statop;i++)
    {
        ans = max(ans,spa::getqianqu(k,sta[i]));
    }
    return ans;
}
inline int solve5()
{
    int l,r,k;
    scanf("%d%d%d",&l,&r,&k);
    statop=0;
    seg::getqujian(1,1,n,l,r);
    int ans = 0x3f3f3f3f;
    for(int i=1;i<=statop;i++)
    {
        ans = min(ans,spa::gethouji(k,sta[i]));
    }
    return ans;
}
int main()
{

    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%d",&A[i]),MAX=max(MAX,A[i]);
    seg::maketree(1,n);
    for(int i=1;i<=m;i++)
    {
        int opt;
        scanf("%d",&opt);
        if(opt==1)  printf("%d\n",solve1());
        else if(opt==2) printf("%d\n",solve2());
        else if(opt==3) solve3();
        else if(opt==4) printf("%d\n",solve4());
        else printf("%d\n",solve5());
    }
}  
 
posted @ 2018-06-22 23:30  Newuser233  阅读(4)  评论(0)    收藏  举报