BZOJ 1146 二分+链剖+线段树+treap

思路:
恶心的数据结构题……

首先 我们 链剖 把树 变成序列 再 套一个 区间 第K大就好了……

复杂度(n*log^4n)

//By SiriusRen
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 88888
#define inf 100000000
int n,q,first[N],next[N*2],v[N*2],t[N],tot,op,xx,yy;
int fa[N],son[N],deep[N],top[N],siz[N],cnt,ch[N];
int root[N*40],size;
struct Treap{int ch[2],sz,cnt,v,rnd;}tr[N*40];
void Upd(int k){tr[k].sz=tr[tr[k].ch[0]].sz+tr[tr[k].ch[1]].sz+tr[k].cnt;}
void Rot(int &k,bool f){int t=tr[k].ch[f];tr[k].ch[f]=tr[t].ch[!f],tr[t].ch[!f]=k,Upd(k),Upd(t),k=t;}
void Insert(int &k,int num){
    if(!k){k=++size;tr[k].sz=tr[k].cnt=1,tr[k].rnd=rand();tr[k].v=num;return;}
    tr[k].sz++;
    if(tr[k].v==num){tr[k].cnt++;return;}
    bool f=num>tr[k].v;
    Insert(tr[k].ch[f],num);
    if(tr[tr[k].ch[f]].rnd<tr[k].rnd)Rot(k,f);
}
void Del(int &k,int num){
    if(tr[k].v==num){
        if(tr[k].cnt>1)tr[k].cnt--,tr[k].sz--;
        else if(tr[k].ch[0]*tr[k].ch[1]==0)k=max(tr[k].ch[0],tr[k].ch[1]);
        else Rot(k,tr[tr[k].ch[0]].rnd>tr[tr[k].ch[1]].rnd),Del(k,num);
    }
    else tr[k].sz--,Del(tr[k].ch[num>tr[k].v],num);
}
int get_rank(int k,int num){
    if(!k)return 0;
    if(tr[k].v==num)return tr[tr[k].ch[1]].sz;
    else if(tr[k].v<num)return get_rank(tr[k].ch[1],num);
    else return get_rank(tr[k].ch[0],num)+tr[tr[k].ch[1]].sz+tr[k].cnt;
}
void insert(int l,int r,int pos,int num,int wei){
    Insert(root[pos],wei);
    if(l==r)return;
    int mid=(l+r)>>1,lson=pos<<1,rson=pos<<1|1;
    if(mid<num)insert(mid+1,r,rson,num,wei);
    else insert(l,mid,lson,num,wei);
}
void change(int l,int r,int pos,int num,int wei){
    Del(root[pos],t[xx]),Insert(root[pos],wei);
    if(l==r)return;
    int mid=(l+r)>>1,lson=pos<<1,rson=pos<<1|1;
    if(mid<num)change(mid+1,r,rson,num,wei);
    else change(l,mid,lson,num,wei);
}
int query(int l,int r,int pos,int L,int R,int num){
    if(l>=L&&r<=R)return get_rank(root[pos],num);
    int mid=(l+r)>>1,lson=pos<<1,rson=pos<<1|1;
    if(mid<L)return query(mid+1,r,rson,L,R,num);
    else if(mid>=R)return query(l,mid,lson,L,R,num);
    else return query(l,mid,lson,L,R,num)+query(mid+1,r,rson,L,R,num);
}
void Add(int x,int y){v[tot]=y,next[tot]=first[x],first[x]=tot++;}
void add(int x,int y){Add(x,y),Add(y,x);}
void dfs(int x){
    siz[x]=1;
    for(int i=first[x];~i;i=next[i])
        if(v[i]!=fa[x]){
            fa[v[i]]=x,deep[v[i]]=deep[x]+1;
            dfs(v[i]),siz[x]+=siz[v[i]];
            if(siz[v[i]]>siz[son[x]])son[x]=v[i];
        }
}
void dfs2(int x,int tp){
    top[x]=tp,ch[x]=++cnt;
    insert(1,n,1,cnt,t[x]);
    if(son[x])dfs2(son[x],tp);
    for(int i=first[x];~i;i=next[i])
        if(v[i]!=fa[x]&&v[i]!=son[x])
            dfs2(v[i],v[i]);
}
int find(int x,int y,int num){
    int fx=top[x],fy=top[y],tmp=0;
    while(fx!=fy){
        if(deep[fx]<deep[fy])swap(fx,fy),swap(x,y);
        tmp+=query(1,n,1,ch[fx],ch[x],num);
        x=fa[fx],fx=top[x];
    }
    if(deep[x]>deep[y])swap(x,y);
    return tmp+query(1,n,1,ch[x],ch[y],num);
}
void b_srch(){
    int l=0,r=inf,ans;
    while(l<=r){
        int mid=(l+r)>>1;
        if(find(xx,yy,mid)>=op)l=mid+1;
        else ans=mid,r=mid-1;
    }
    if(!ans)puts("invalid request!");
    else printf("%d\n",ans);
}
int main(){
    memset(first,-1,sizeof(first));
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;i++)scanf("%d",&t[i]);
    for(int i=1;i<n;i++)scanf("%d%d",&xx,&yy),add(xx,yy);
    dfs(1),dfs2(1,1);
    for(int i=1;i<=q;i++){
        scanf("%d%d%d",&op,&xx,&yy);
        if(op)b_srch();
        else change(1,n,1,ch[xx],yy),t[xx]=yy;
    }
}

这里写图片描述

posted @ 2016-12-12 10:46  SiriusRen  阅读(154)  评论(0编辑  收藏  举报