【树论】两棵树

WC2018 T1 简化版 by OBlack.
有两棵有 n​ 个节点的树,分别为 A,B,树上每条边都有一个权值 v_i​ 令 disA(x,y) 和 disB(x,y) 分别表示 x 节点与 y 节点在树 A,B 上的距离 请你找出一个点对 (x,y) 使得 disA(x,y)+disB(x,y) 取得最大值,注意 x!=y
输入格式
第一行,一个整数 n 接下来 n-1 行,每行三个整数 x_i,y_i,v_i ,表示 A 上的一条边 接下来 n-1 行,每行三个整数 x_i,y_i,v_i ,表示 B 上的一条边
输出格式
输出一个整数 Ans ,表示最大值
样例输入
5 4 5 6 4 1 2 5 2 1 4 3 0 4 1 9 4 2 0 1 5 0 5 3 7
样例输出
23
提示
对于 30% 的数据, n<=2000,v_i<=10^6 对于所有数据, n<=10^5, v_i<=10^12 本题为  WC2018 T1 (NKOJ  5004) 简化版

题解

我们将dis1(x,y)+dis2(x,y)变一下形 --> dep1[x] + dep1[y] - 2*dep1[lca1(x,y)] + dep2[x] + dep2[y] - 2*dep2[lca2(x,y)] 我们发现对于任何x,他的dep1[x]和dep2[x]在式子中总是成对出现的。那么我们将他们合并在一起考虑-->方法是将第二棵树中的x,每个对应在其下连一条dep1[x]的边,并设置新节点为x'。同时我们在第一棵树中枚举lca,那么问题就变成了--->对于第一棵树的lca,对于他的分别两个子树合并的时候的两个点集,分别对应在第二棵树中的两个离散的点集,在其中找一条最长的路径。也就是比如对于x的son1对应一些点集,son2对应一些点集,(因为在这两个点集中各选一个保证他们在第一棵树树上的lca一定为x)然后两个中各自挑选一个的最长路径。 其实对于这样两个离散点集中,分别各自对应着两个直径。容易证明。对于两个离散点集,他们之间的最长路径的端点一定是分别在他们的直径4个点上。同时,也可以证明,对于两个离散点集他们合并之后的那个点集的直径的端点,也一定是两个离散点集的四个端点之二。 这样我们只需要变换一下第二颗树,然后在第一棵树上dfs合并子树,在第二颗树上查询最长路径更新答案(LCA次数比较多,考虑O(1)LCA),就可以了。。
/*
在第一棵树上枚举lca,由于式子
disa(x,y) + disb(x,y) == dep1x + dep1y -2*lca1 + dep2x + dep2y - 2*lca2
可以在第二棵树上外挂一个节点,那么就是对于第二棵树上的一个点集内的求点集了
然后直接直径合并就可以了 
*/
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cmath>
#define pr pair<int,int>
#define fi first
#define se second
using namespace std;
typedef long long ll;
const int maxn = 1000005;
const int maxm = 2000005;
ll len1[maxm];
int la1[maxn],nt1[maxm],en1[maxm],owo1;
void addedge1(int a,int b,ll c) {
    en1[++owo1] = b; nt1[owo1] = la1[a];
    la1[a] = owo1; len1[owo1] = c;
}
ll len2[maxm];
int la2[maxn],nt2[maxm],en2[maxm],owo2;
void addedge2(int a,int b,ll c) {
    en2[++owo2] = b; nt2[owo2] = la2[a];
    la2[a] = owo2; len2[owo2] = c;
}
ll dep1[maxn],dep2[maxn];
int n;
int ST[23][maxn*2],lac[maxn],oula;
char buf[1<<20],*p1,*p2;
#define GC (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2)?0:*p1++)
//#define GC getchar()
inline ll R()
{
    char t=GC;
    ll x=0;
    while(!isdigit(t)) t=GC;
    while(isdigit(t)) x=x*10+t-48,t=GC;
    return x;
}
void dfs1(int x,int ba) {
    for(int it=la1[x];it;it=nt1[it]) {
        int y = en1[it];
        if(y==ba) continue; 
        dep1[y] = dep1[x] + len1[it];
        dfs1(y,x);
    }
}
int logg[maxn*2];
int S;
void dfs2(int x,int ba) {
    ST[0][++oula] = x;
    for(int it=la2[x];it;it=nt2[it]) {
        int y = en2[it];
        if(y==ba) continue;
        dep2[y] = dep2[x] + len2[it];
        dfs2(y,x);
        ST[0][++oula] = x;
    }
    lac[x] = oula;
}
ll ans = -1e18;
int getlca(int x,int y) {
    x = lac[x]; y = lac[y];
    if(x>y) swap(x,y);
    int k = logg[y-x+1];
    return dep2[ST[k][x]] < dep2[ST[k][y-(1<<k)+1]] ? ST[k][x] : ST[k][y-(1<<k)+1];
}
ll getcd(int x,int y) {
    int lcc = getlca(x,y);
    return dep2[x] + dep2[y] - 2*dep2[lcc];
}
pr merge(pr a,pr b) {
    pr now = a; ll ccd = getcd(a.fi,a.se);
    ll tmp = getcd(b.fi,b.se);
    if(tmp>ccd) {
        now = b; ccd = tmp;
    }
    
    tmp = getcd(a.fi,b.fi);
    if(tmp>ccd) {
        now = pr(a.fi,b.fi); ccd = tmp;
    }
    
    tmp = getcd(a.fi,b.se);
    if(tmp>ccd) {
        now = pr(a.fi,b.se); ccd = tmp;
    }
    
    tmp = getcd(a.se,b.fi);
    if(tmp>ccd) {
        now = pr(a.se,b.fi); ccd = tmp;
    }
    
    tmp = getcd(a.se,b.se);
    if(tmp>ccd) {
        now = pr(a.se,b.se); ccd = tmp;
    }
    return now;
}
pr dfs3(int x,int ba) {
    pr zj = pr(x+n,x+n);
    for(int it=la1[x];it;it=nt1[it]) {
        int y = en1[it];
        if(y==ba) continue;
        pr tmp = dfs3(y,x);
        ans = max(getcd(tmp.fi,zj.fi)-2*dep1[x],ans);
        ans = max(getcd(tmp.fi,zj.se)-2*dep1[x],ans);
        ans = max(getcd(tmp.se,zj.fi)-2*dep1[x],ans);
        ans = max(getcd(tmp.se,zj.se)-2*dep1[x],ans);
        zj = merge(zj,tmp);
    }
    return zj;
}
main() {
   // freopen("aha.in","r",stdin);
   // freopen("aha.out","w",stdout);
    n=R();
    for(int i=1;i<n;i++) {
        int x,y;ll z; x=R(); y=R(); z=R();
        addedge1(x,y,z); addedge1(y,x,z);
    }
    dfs1(1,0);
    for(int i=1;i<=n;i++) {
        addedge2(i,i+n,dep1[i]);
        addedge2(i+n,i,dep1[i]);
    }
    for(int i=1;i<n;i++) {
        int x,y;ll z; x=R();y=R();z=R();
        addedge2(x,y,z); addedge2(y,x,z);
    }
    dfs2(1,0);
    logg[0]=logg[1]=0;
    for(int i=2;i<=oula;i++) logg[i] = logg[i>>1] + 1;
    for(int s=1;s<=logg[oula];s++) {
        for(int i=1;i+(1<<s)-1<=oula;i++) {
            ST[s][i] = dep2[ST[s-1][i]] < dep2[ST[s-1][i+(1<<(s-1))]] ? ST[s-1][i] : ST[s-1][i+(1<<(s-1))];
        }
    }
    dfs3(1,0);
    printf("%lld",ans);
}
 
posted @ 2018-10-16 19:44  Newuser233  阅读(8)  评论(0)    收藏  举报