hihoCoder #1381 : Little Y's Tree

http://hihocoder.com/problemset/problem/1381

 

一个结论:集合A的直径为a--b,集合B的直径为c--d,那么集合A∪B的直径必是以下6种之一:

a--b  c--d  a--c  a--d  b--c  b--d

 

断掉一条边,相当于从树的dfs序上取出一段区间

用线段树维护dfs序上任意一段区间的直径

如果[1,10]断掉的是[1,4] [3,4] [7,8]

答案就是[1,2]的直径+[3,4]的直径+[5,6]∪[9,10]的直径+[7,8]的直径

#include<cstdio>
#include<vector>
#include<iostream>
#include<algorithm>

using namespace std;

#define N 100001
typedef long long LL;

int n;
int tot,front[N],nxt[N<<1],to[N<<1],val[N<<1];

int lo2[N];

int dep[N],dy[N],fa[N][18];
LL dis[N];
int ll[N],rr[N];

vector<int>inc[N];
int st[N],top;

int a[N],bin[N],cnt;
LL ans;

struct node
{
    int a,b;
    LL dis;
}tr[N<<2];

void read(int &x)
{
    x=0; char c=getchar();
    while(!isdigit(c)) c=getchar();
    while(isdigit(c)) { x=x*10+c-'0'; c=getchar(); }
}

void add(int u,int v,int w)
{
    to[++tot]=v; nxt[tot]=front[u]; front[u]=tot; val[tot]=w;
    to[++tot]=u; nxt[tot]=front[v]; front[v]=tot; val[tot]=w;
}

void init()
{
    read(n);
    int u,v,w;
    for(int i=1;i<n;++i)
    {
        read(u); read(v); read(w);
        add(u,v,w);
    }
}

void dfs(int x)
{
    ll[x]=++tot;
    dy[tot]=x;
    int t;
    for(int i=front[x];i;i=nxt[i])
    {
        t=to[i];
        if(fa[t][0]) continue;
        fa[t][0]=x;
        dis[t]=dis[x]+val[i];
        dep[t]=dep[x]+1;
        dfs(t);
    }
    rr[x]=tot;
}

void pre()
{
    for(int i=2;i<=n;++i) lo2[i]=lo2[i>>1]+1; 
    fa[1][0]=-1;
    tot=0;
    dfs(1);
    int m=lo2[n]; 
    fa[1][0]=0; 
    for(int i=1;i<=n;++i)
        for(int j=1;j<=m;++j)
            fa[i][j]=fa[fa[i][j-1]][j-1];
            
}

int getlca(int x,int y)
{
    int m=lo2[dep[x]];
    for(int i=m;i>=0;--i)
        if(ll[fa[x][i]]>ll[y]) x=fa[x][i];
    return fa[x][0];
}

LL getdis(int x,int y)
{
    x=dy[x];
    y=dy[y];
    if(x==y) return 0;
    if(ll[x]<ll[y]) swap(x,y);
    int lca=getlca(x,y);
    if(lca==y) return dis[x]-dis[y];
    return dis[x]+dis[y]-dis[lca]*2;
}

node unionn(node p,node q)
{
    node t1=(node){p.a,q.a,getdis(p.a,q.a)};
    node t2=(node){p.a,q.b,getdis(p.a,q.b)};
    node t3=(node){p.b,q.a,getdis(p.b,q.a)};
    node t4=(node){p.b,q.b,getdis(p.b,q.b)};
    node t=p;
    if(q.dis>t.dis) t=q;
    if(t1.dis>t.dis) t=t1;
    if(t2.dis>t.dis) t=t2;
    if(t3.dis>t.dis) t=t3;
    if(t4.dis>t.dis) t=t4;
    return t;    
}

void build(int k,int l,int r)
{
    if(l==r)
    {
        tr[k].a=tr[k].b=l;
        return;
    }
    int mid=l+r>>1;
    build(k<<1,l,mid);
    build(k<<1|1,mid+1,r);
    tr[k]=unionn(tr[k<<1],tr[k<<1|1]);
}

bool cmp(int p,int q)
{
    return ll[p]<ll[q];
}

node query(int k,int l,int r,int opl,int opr)
{
    if(l>=opl && r<=opr) return tr[k];
    int mid=l+r>>1;
    if(opr<=mid) return query(k<<1,l,mid,opl,opr);
    if(opl>mid) return query(k<<1|1,mid+1,r,opl,opr);
    node tmp1=query(k<<1,l,mid,opl,opr);
    node tmp2=query(k<<1|1,mid+1,r,opl,opr);
    return unionn(tmp1,tmp2);    
}

void dfs2(int x)
{
    int l=ll[x],m=inc[x].size(),t;
    node tmp,mx;
    mx.a=mx.b=ll[x];
    mx.dis=0;
    for(int i=0;i<m;++i)
    {
        t=inc[x][i];
        if(ll[t]!=l)
        {
            tmp=query(1,1,n,l,ll[t]-1);
            mx=unionn(mx,tmp);
        }
        dfs2(t);
        l=rr[t]+1;
    }
    if(l!=rr[x]+1) 
    {
        tmp=query(1,1,n,l,rr[x]);
        mx=unionn(mx,tmp);
    }
    ans+=mx.dis;    
}

void solve()
{
    int m,k,x,y;
    read(m);
    while(m--)
    {
        for(int i=1;i<=cnt;++i) inc[bin[i]].clear();
        cnt=ans=0;
        st[top=1]=1;
        read(k);
        for(int i=1;i<=k;++i) 
        {
            read(x);
            x<<=1;
            if(ll[to[x]]<ll[to[x-1]]) y=to[x-1];
            else y=to[x];
            a[i]=y;
        }
        sort(a+1,a+k+1,cmp);
        for(int i=1;i<=k;++i)
        {
            y=a[i];
            while(!(ll[y]>=ll[st[top]] && rr[y]<=rr[st[top]])) top--;
            inc[st[top]].push_back(y);
            bin[++cnt]=st[top];
            st[++top]=y; 
        }
        dfs2(1);
        cout<<ans<<'\n';
    }
}

int main()
{
    init();
    pre();
    build(1,1,n);
    solve();
} 

 

时间限制:24000ms
单点时限:4000ms
内存限制:512MB

描述

小Y有一棵n个节点的树,每条边都有正的边权。

小J有q个询问,每次小J会删掉这个树中的k条边,这棵树被分成k+1个连通块。小J想知道每个连通块中最远点对距离的和。

这里的询问是互相独立的,即每次都是在小Y的原树上进行操作。

输入

第一行一个整数n,接下来n-1行每行三个整数u,v,w,其中第i行表示第i条边边权为wi,连接了ui,vi两点。

接下来一行一个整数q,表示有q组询问。

对于每组询问,第一行一个正整数k,接下来一行k个不同的1到n-1之间的整数,表示删除的边的编号。

1<=n,q,Σk<=105, 1<=w<=109

输出

共q行,每行一个整数表示询问的答案。

样例输入
5
1 2 2
2 3 3
2 4 4
4 5 2
3
4 1 2 3 4
1 4
2 2 3
样例输出
0
7
4
posted @ 2019-11-18 12:59  TRTTG  阅读(252)  评论(0编辑  收藏  举报