【51Nod1766】树上的最远点对-线段树+树的直径

测试地址:树上的最远点对
做法:本题需要用到线段树+树的直径。
我们直觉上感觉这题的区间询问需要用到线段树,那么我们就要发掘出问题中隐藏的可合并的性质。而看见最远点对又会想到树的直径,那么树的直径这个信息是不是可以快速合并的呢?怎么合并呢?下面我们就夯实理论基础,这样我们就能很简单地解决这一题了。
首先有一个结论:树上任意一个点在树中的最远点是树的直径的某个端点。我们可以用反证法轻易地证明这一点。再扩展一下,有以下结论:树上任意一个点在树中的一个点集中的最远点是该点集中最长链的一个端点。其实我们把点集等价地看为一棵虚树,然后就能用相似的证法解决了。
那么现在的问题是,我们有两个点集,要求两个点集之间的最长链,上面的结论对解决这个问题有什么帮助呢?首先对一个集合用上面的结论,于是就知道一个端点在另一个集合的最长链的两个端点中,再反过来用上面的结论,我们就能得到最终的结论:两个点集之间的最长链的两个端点,分别是两个点集中各自的最长链的某个端点。
显然地,如果把两个点集合并起来,那么合并后点集的最长链,就等同于两个点集各自的最长链,和两个点集之间的最长链之中的最大值。因此,利用欧拉序+ST表的LCA算法,我们可以在O(1)的时间内合并这个信息(这题时限很紧,用倍增会T),于是就可以用线段树做了,那么每次询问一个区间的最长链就是O(logn)的了,而最后的询问也恰好就是问两个点集之间的最长链,直接用结论即可。这样我们就以O(nlogn)的时间复杂度解决了此题。
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m,first[100010]={0},tot=0;
int st[200010],top=0,dep[100010],pos[100010];
int mnp[200010][21],p[200010];
int seg[400010][2];
ll dis[100010];
struct edge
{
    int v,next;
    ll d;
}e[200010];

void insert(int a,int b,ll d)
{
    e[++tot].v=b;
    e[tot].d=d;
    e[tot].next=first[a];
    first[a]=tot;
}

void dfs(int v,int fa)
{
    st[++top]=v;
    pos[v]=top;

    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=fa)
        {
            dep[e[i].v]=dep[v]+1;
            dis[e[i].v]=dis[v]+e[i].d;
            dfs(e[i].v,v);
            st[++top]=v;
        }
}

void rmq_init()
{
    for(int i=1;i<=top;i++)
        mnp[i][0]=st[i];
    for(int i=1;i<=20;i++)
        for(int j=1;j+(1<<i)-1<=top;j++)
        {
            if (dep[mnp[j][i-1]]<dep[mnp[j+(1<<(i-1))][i-1]])
                mnp[j][i]=mnp[j][i-1];
            else mnp[j][i]=mnp[j+(1<<(i-1))][i-1];
        }
    p[1]=0;
    for(int i=2;i<=top;i++)
    {
        if (1<<(p[i-1]+1)<i) p[i]=p[i-1]+1;
        else p[i]=p[i-1];
    }
}

int rmq(int l,int r)
{
    int x=p[r-l+1];
    if (dep[mnp[l][x]]<dep[mnp[r-(1<<x)+1][x]])
        return mnp[l][x];
    else return mnp[r-(1<<x)+1][x];
}

ll dist(int a,int b)
{
    int g=rmq(min(pos[a],pos[b]),max(pos[a],pos[b]));
    return dis[a]+dis[b]-2ll*dis[g];
}

void merge(int a1,int a2,int b1,int b2,int &s1,int &s2)
{
    ll ans=0;
    if (dist(a1,a2)>ans) ans=dist(a1,a2),s1=a1,s2=a2;
    if (dist(a1,b1)>ans) ans=dist(a1,b1),s1=a1,s2=b1;
    if (dist(a1,b2)>ans) ans=dist(a1,b2),s1=a1,s2=b2;
    if (dist(a2,b1)>ans) ans=dist(a2,b1),s1=a2,s2=b1;
    if (dist(a2,b2)>ans) ans=dist(a2,b2),s1=a2,s2=b2;
    if (dist(b1,b2)>ans) ans=dist(b1,b2),s1=b1,s2=b2;
}

void pushup(int no)
{
    merge(seg[no<<1][0],seg[no<<1][1],seg[no<<1|1][0],seg[no<<1|1][1],seg[no][0],seg[no][1]);
}

void buildtree(int no,int l,int r)
{
    if (l==r)
    {
        seg[no][0]=seg[no][1]=l;
        return;
    }
    int mid=(l+r)>>1;
    buildtree(no<<1,l,mid);
    buildtree(no<<1|1,mid+1,r);
    pushup(no);
}

void query(int no,int l,int r,int s,int t,int &ans1,int &ans2)
{
    if (l>=s&&r<=t)
    {
        merge(ans1,ans2,seg[no][0],seg[no][1],ans1,ans2);
        return;
    }
    int mid=(l+r)>>1;
    if (s<=mid) query(no<<1,l,mid,s,t,ans1,ans2);
    if (t>mid) query(no<<1|1,mid+1,r,s,t,ans1,ans2);
}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int x,y;
        ll z;
        scanf("%d%d%lld",&x,&y,&z);
        insert(x,y,z),insert(y,x,z);
    }

    top=0;
    dep[1]=dis[1]=0;
    dfs(1,0);
    rmq_init();

    buildtree(1,1,n);
    scanf("%d",&m);
    for(int i=1;i<=m;i++)
    {
        int a,b,c,d;
        scanf("%d%d%d%d",&a,&b,&c,&d);
        int a1=a,a2=a,b1=c,b2=c;
        ll ans=0;
        query(1,1,n,a,b,a1,a2);
        query(1,1,n,c,d,b1,b2);
        ans=max(ans,dist(a1,b1));
        ans=max(ans,dist(a1,b2));
        ans=max(ans,dist(a2,b1));
        ans=max(ans,dist(a2,b2));
        printf("%lld\n",ans);
    }

    return 0;
}
posted @ 2018-08-06 10:59  Maxwei_wzj  阅读(133)  评论(0编辑  收藏  举报