【BZOJ4539】树(HNOI2016)-主席树+LCA

测试地址:
做法:本题需要用到主席树+LCA。
要求两点间的距离,显然要维护每个点的深度,以及要求两个点的LCA。
我们把一开始的树看成一块,然后每次操作,都是在某一块下面挂一个新的块,每个块都是模板树的一棵子树。这样我们可以先把块缩成点,那么缩块后整棵大树就变成了一棵更小的树。考虑求一个点的深度,这个深度等于它到它所在块的根的距离,加上块根到整棵树根节点的距离,显然前面的部分可以直接在模板树上求出,处理出深度后就可以O(1)询问,那我们又要维护块根在模板树中对应的点编号,可以直接在构造过程中维护,而后面的部分就可以直接在构造时维护。
求深度的问题解决了,现在要解决求LCA的问题了。注意到,两个点的LCA一定在它们所属块在缩块树上的LCA所对应的块中,于是我们倍增跳到LCA的下方,这时候我们需要知道这棵树挂在了哪个点上才能进入到LCA块中,于是我们在构造时维护块根上面的点在模板树中对应的点编号。那么最后我们就得到了在同一块中的两个点,直接在模板树中倍增求出LCA即可。这样一次询问的时间复杂度就是常数稍大的O(logn)了。注意某些特殊情况,例如一开始两个点就在同一块中,或者其中一块是另一块的祖先。
上面这一通操作看上去完美,但实际上还有一个问题:因为点数可能达到1010,不能直接存储,那我们如何快速定位一个点在哪个块中,并且它在模板树中对应哪个点呢?注意到每块中节点的编号是连续的,因此我们可以二分定位该点所在块,时间复杂度为O(logn),而每一块中节点的编号顺序和模板树中编号顺序相同,因此要求该块中第k个编号的点在模板树中对应的点编号,就是求在模板树的对应子树中第k小的编号,我们知道树上的子树第k小可以转化为DFS序上的区间第k小,这就是主席树的经典应用了,于是我们做到了一次定位O(logn)的时间复杂度。
那么我们就完成了这一题,时间复杂度为O(nlogn)
我傻逼的地方:漫长的四个小时告诉我们,永远都不要用相似的名字命名不同的东西……
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,q,first[100010]={0},tot=0,fa[100010][21]={0};
int in[100010],out[100010],tim=0,pos[100010];
int rt[100010]={0},seg[2000010]={0},ch[2000010][2]={0};
int blocktop[100010],blockfa[100010][21]={0};
int blockup[100010]={0},blockdepth[100010];
ll dep[100010],siz[100010],blockl[100010],blockdep[100010],totsiz;
struct edge
{
    int v,next;
}e[200010];

void insert(int a,int b)
{
    e[++tot].v=b;
    e[tot].next=first[a];
    first[a]=tot;
}

void dfs(int v)
{
    in[v]=++tim;
    pos[tim]=v;
    siz[v]=1;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=fa[v][0])
        {
            dep[e[i].v]=dep[v]+1;
            fa[e[i].v][0]=v;
            dfs(e[i].v);
            siz[v]+=siz[e[i].v];
        }
    out[v]=tim;
}

void buildtree(int &v,int l,int r)
{
    v=++tot;
    if (l==r) return;
    int mid=(l+r)>>1;
    buildtree(ch[v][0],l,mid);
    buildtree(ch[v][1],mid+1,r);
}

void add(int &v,int last,int l,int r,int x)
{
    v=++tot;
    seg[v]=seg[last];
    ch[v][0]=ch[last][0];
    ch[v][1]=ch[last][1];
    if (l==r)
    {
        seg[v]++;
        return;
    }
    int mid=(l+r)>>1;
    if (x<=mid) add(ch[v][0],ch[last][0],l,mid,x);
    else add(ch[v][1],ch[last][1],mid+1,r,x);
    seg[v]=seg[ch[v][0]]+seg[ch[v][1]];
}

int findkth(int v,int last,int l,int r,int k)
{
    if (l==r) return l;
    int mid=(l+r)>>1;
    if (seg[ch[v][0]]-seg[ch[last][0]]<k)
    {
        k-=seg[ch[v][0]]-seg[ch[last][0]];
        return findkth(ch[v][1],ch[last][1],mid+1,r,k);
    }
    else return findkth(ch[v][0],ch[last][0],l,mid,k);
}

void init()
{
    scanf("%d%d%d",&n,&m,&q);
    for(int i=1;i<n;i++)
    {
        int a,b;
        scanf("%d%d",&a,&b);
        insert(a,b),insert(b,a);
    }

    dep[0]=-1;
    dep[1]=0;
    dfs(1);
    for(int i=1;i<=18;i++)
        for(int j=1;j<=n;j++)
            fa[j][i]=fa[fa[j][i-1]][i-1];

    tot=0;
    buildtree(rt[0],1,n);
    for(int i=1;i<=n;i++)
        add(rt[i],rt[i-1],1,n,pos[i]);
}

int findblock(ll v,int limit)
{
    int l=1,r=limit;
    while(l<r)
    {
        int mid=(l+r)>>1;
        if (v>=blockl[mid+1]) l=mid+1;
        else r=mid;
    }
    return l;
}

int findpoint(int now,int k)
{
    int p=blocktop[now];
    return findkth(rt[out[p]],rt[in[p]-1],1,n,k);
}

void buildblock()
{
    totsiz=n;
    blockfa[1][0]=0;
    blockdep[1]=0;
    blocktop[1]=1;
    blockl[1]=1;
    blockup[1]=0;
    blockdepth[1]=0;
    blockdepth[0]=-1;

    for(int i=2;i<=m+1;i++)
    {
        int a,now,p;
        ll b,depth;
        scanf("%d%lld",&a,&b);

        now=findblock(b,i-1);
        p=findpoint(now,b-blockl[now]+1);
        depth=blockdep[now]+dep[p]-dep[blocktop[now]];

        blockfa[i][0]=now;
        blockdep[i]=depth+1;
        blocktop[i]=a;
        blockup[i]=p;
        blockl[i]=totsiz+1;
        blockdepth[i]=blockdepth[now]+1;
        totsiz+=siz[a];
    }

    for(int i=1;i<=18;i++)
        for(int j=1;j<=m+1;j++)
            blockfa[j][i]=blockfa[blockfa[j][i-1]][i-1];
}

int lca(int a,int b)
{
    if (dep[a]<dep[b]) swap(a,b);
    for(int i=18;i>=0;i--)
        if (dep[fa[a][i]]>=dep[b]) a=fa[a][i];
    if (a==b) return a;
    for(int i=18;i>=0;i--)
        if (fa[a][i]!=fa[b][i])
            a=fa[a][i],b=fa[b][i];
    return fa[a][0];
}

void work()
{
    for(int i=1;i<=q;i++)
    {
        ll a,b,ans=0;
        int nowa,nowb,pa,pb;
        int ansa,ansb,ansc;
        scanf("%lld%lld",&a,&b);
        nowa=findblock(a,m+1),nowb=findblock(b,m+1);

        if (blockdepth[nowa]<blockdepth[nowb])
            swap(nowa,nowb),swap(a,b);
        int x=nowa,y=nowb;
        for(int j=18;j>=0;j--)
            if (blockdepth[blockfa[x][j]]>blockdepth[y])
                x=blockfa[x][j];

        if (x==y||blockfa[x][0]==y)
        {
            if (x==y) pa=findpoint(nowa,a-blockl[nowa]+1);
            else pa=blockup[x];
            pb=findpoint(nowb,b-blockl[nowb]+1);
            if (x!=y) x=blockfa[x][0];
        }
        else
        {
            if (blockdepth[x]>blockdepth[y]) x=blockfa[x][0];
            for(int j=18;j>=0;j--)
                if (blockfa[x][j]!=blockfa[y][j])
                    x=blockfa[x][j],y=blockfa[y][j];
            pa=blockup[x];
            pb=blockup[y];
            x=blockfa[x][0];
        }

        ansa=findpoint(nowa,a-blockl[nowa]+1);
        ansb=findpoint(nowb,b-blockl[nowb]+1);
        ansc=lca(pa,pb);

        ans+=blockdep[nowa]+dep[ansa]-dep[blocktop[nowa]];
        ans+=blockdep[nowb]+dep[ansb]-dep[blocktop[nowb]];
        ans-=(blockdep[x]+dep[ansc]-dep[blocktop[x]])<<1;
        printf("%lld\n",ans);
    }
}

int main()
{
    init();
    buildblock();
    work();

    return 0;
}
posted @ 2018-05-29 15:49  Maxwei_wzj  阅读(118)  评论(0编辑  收藏  举报